diff --git a/.travis.yml b/.travis.yml index c8c19d980..bd3a5310c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,10 +15,10 @@ addons: matrix: include: - rust: nightly - - rust: stable +# - rust: stable before_deploy: cargo doc --no-deps - allow_failures: - - rust: nightly +# allow_failures: +# - rust: nightly before_script: - cargo clean @@ -39,8 +39,8 @@ script: # Run integration tests - cargo test -p h2-tests - # Run h2spec on stable - - if [ "${TRAVIS_RUST_VERSION}" = "stable" ]; then ./ci/h2spec.sh; fi + # Run h2spec on nightly for the time being. TODO: Change it to stable after Rust 1.38 release + - if [ "${TRAVIS_RUST_VERSION}" = "nightly" ]; then ./ci/h2spec.sh; fi # Check minimal versions - if [ "${TRAVIS_RUST_VERSION}" = "nightly" ]; then cargo clean; cargo check -Z minimal-versions; fi diff --git a/Cargo.toml b/Cargo.toml index bc60a995f..cec4cdee5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,8 +41,9 @@ members = [ ] [dependencies] -futures = "0.1" -tokio-io = "0.1.4" +futures-preview = "0.3.0-alpha.17" +tokio-io = "0.2.0-alpha.1" +tokio-codec = "0.2.0-alpha.1" bytes = "0.4.7" http = "0.1.8" log = "0.4.1" @@ -64,7 +65,7 @@ serde = "1.0.0" serde_json = "1.0.0" # Akamai example -tokio = "0.1.8" +tokio = "0.2.0-alpha.1" env_logger = { version = "0.5.3", default-features = false } rustls = "0.12" tokio-rustls = "0.5.0" diff --git a/examples/akamai.rs b/examples/akamai.rs index 0f8f7ce20..2222c97de 100644 --- a/examples/akamai.rs +++ b/examples/akamai.rs @@ -1,3 +1,10 @@ +fn main() { + // Enable the below code once tokio_rustls moves to std::future +} + +/* +#![feature(async_await)] + use h2::client; use futures::*; @@ -10,10 +17,12 @@ use tokio_rustls::ClientConfigExt; use webpki::DNSNameRef; use std::net::ToSocketAddrs; +use std::error::Error; const ALPN_H2: &str = "h2"; -pub fn main() { +#[tokio::main] +pub async fn main() -> Result<(), Box> { let _ = env_logger::try_init(); let tls_client_config = std::sync::Arc::new({ @@ -33,49 +42,30 @@ pub fn main() { println!("ADDR: {:?}", addr); - let tcp = TcpStream::connect(&addr); + let tcp = TcpStream::connect(&addr).await?; let dns_name = DNSNameRef::try_from_ascii_str("http2.akamai.com").unwrap(); - - let tcp = tcp.then(move |res| { - let tcp = res.unwrap(); - tls_client_config - .connect_async(dns_name, tcp) - .then(|res| { - let tls = res.unwrap(); - { - let (_, session) = tls.get_ref(); - let negotiated_protocol = session.get_alpn_protocol(); - assert_eq!(Some(ALPN_H2), negotiated_protocol.as_ref().map(|x| &**x)); - } - - println!("Starting client handshake"); - client::handshake(tls) - }) - .then(|res| { - let (mut client, h2) = res.unwrap(); - - let request = Request::builder() + let res = tls_client_config.connect_async(dns_name, tcp).await; + let tls = res.unwrap(); + { + let (_, session) = tls.get_ref(); + let negotiated_protocol = session.get_alpn_protocol(); + assert_eq!(Some(ALPN_H2), negotiated_protocol.as_ref().map(|x| &**x)); + } + + println!("Starting client handshake"); + let (mut client, h2) = client::handshake(tls).await?; + + let request = Request::builder() .method(Method::GET) .uri("https://http2.akamai.com/") .body(()) .unwrap(); - let (response, _) = client.send_request(request, true).unwrap(); - - let stream = response.and_then(|response| { - let (_, body) = response.into_parts(); - - body.for_each(|chunk| { - println!("RX: {:?}", chunk); - Ok(()) - }) - }); - - h2.join(stream) - }) - }) - .map_err(|e| eprintln!("ERROR: {:?}", e)) - .map(|((), ())| ()); - - tokio::run(tcp); + let (response, _) = client.send_request(request, true).unwrap(); + let (_, mut body) = response.await?.into_parts(); + while let Some(chunk) = body.next().await { + println!("RX: {:?}", chunk?); + } + Ok(()) } +*/ diff --git a/examples/client.rs b/examples/client.rs index 53a014f92..399334483 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,87 +1,56 @@ +#![feature(async_await)] + +use futures::future::poll_fn; +use futures::StreamExt; use h2::client; -use h2::RecvStream; +use http::{HeaderMap, Request}; -use futures::*; -use http::*; +use std::error::Error; use tokio::net::TcpStream; -struct Process { - body: RecvStream, - trailers: bool, -} - -impl Future for Process { - type Item = (); - type Error = h2::Error; - - fn poll(&mut self) -> Poll<(), h2::Error> { - loop { - if self.trailers { - let trailers = try_ready!(self.body.poll_trailers()); - - println!("GOT TRAILERS: {:?}", trailers); - - return Ok(().into()); - } else { - match try_ready!(self.body.poll()) { - Some(chunk) => { - println!("GOT CHUNK = {:?}", chunk); - }, - None => { - self.trailers = true; - }, - } - } - } - } -} - -pub fn main() { +#[tokio::main] +pub async fn main() -> Result<(), Box> { let _ = env_logger::try_init(); - let tcp = TcpStream::connect(&"127.0.0.1:5928".parse().unwrap()); + let tcp = TcpStream::connect(&"127.0.0.1:5928".parse().unwrap()).await?; + let (mut client, h2) = client::handshake(tcp).await?; - let tcp = tcp.then(|res| { - let tcp = res.unwrap(); - client::handshake(tcp) - }).then(|res| { - let (mut client, h2) = res.unwrap(); + println!("sending request"); - println!("sending request"); + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); + let mut trailers = HeaderMap::new(); + trailers.insert("zomg", "hello".parse().unwrap()); - let mut trailers = HeaderMap::new(); - trailers.insert("zomg", "hello".parse().unwrap()); + let (response, mut stream) = client.send_request(request, false).unwrap(); - let (response, mut stream) = client.send_request(request, false).unwrap(); + // send trailers + stream.send_trailers(trailers).unwrap(); - // send trailers - stream.send_trailers(trailers).unwrap(); + // Spawn a task to run the conn... + tokio::spawn(async move { + if let Err(e) = h2.await { + println!("GOT ERR={:?}", e); + } + }); - // Spawn a task to run the conn... - tokio::spawn(h2.map_err(|e| println!("GOT ERR={:?}", e))); + let response = response.await?; + println!("GOT RESPONSE: {:?}", response); - response - .and_then(|response| { - println!("GOT RESPONSE: {:?}", response); + // Get the body + let (_, mut body) = response.into_parts(); - // Get the body - let (_, body) = response.into_parts(); + while let Some(chunk) = body.next().await { + println!("GOT CHUNK = {:?}", chunk?); + } - Process { - body, - trailers: false, - } - }) - .map_err(|e| { - println!("GOT ERR={:?}", e); - }) - }); + if let Some(trailers) = poll_fn(|cx| body.poll_trailers(cx)).await { + println!("GOT TRAILERS: {:?}", trailers?); + } - tokio::run(tcp); + Ok(()) } diff --git a/examples/server.rs b/examples/server.rs index 89e66591b..870e23c72 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,62 +1,50 @@ +#![feature(async_await)] + use h2::server; use bytes::*; use futures::*; -use http::*; +use http::{Response, StatusCode}; -use tokio::net::TcpListener; +use tokio::net::{TcpListener, TcpStream}; +use std::error::Error; -pub fn main() { +#[tokio::main] +pub async fn main() -> Result<(), Box> { let _ = env_logger::try_init(); let listener = TcpListener::bind(&"127.0.0.1:5928".parse().unwrap()).unwrap(); println!("listening on {:?}", listener.local_addr()); + let mut incoming = listener.incoming(); + + while let Some(socket) = incoming.next().await { + tokio::spawn(async move { + if let Err(e) = handle(socket).await { + println!(" -> err={:?}", e); + } + }); + } - let server = listener.incoming().for_each(move |socket| { - // let socket = io_dump::Dump::to_stdout(socket); - - let connection = server::handshake(socket) - .and_then(|conn| { - println!("H2 connection bound"); - - conn.for_each(|(request, mut respond)| { - println!("GOT request: {:?}", request); - - let response = Response::builder().status(StatusCode::OK).body(()).unwrap(); - - let mut send = match respond.send_response(response, false) { - Ok(send) => send, - Err(e) => { - println!(" error respond; err={:?}", e); - return Ok(()); - } - }; - - println!(">>>> sending data"); - if let Err(e) = send.send_data(Bytes::from_static(b"hello world"), true) { - println!(" -> err={:?}", e); - } - - Ok(()) - }) - }) - .and_then(|_| { - println!("~~~~~~~~~~~~~~~~~~~~~~~~~~~ H2 connection CLOSE !!!!!! ~~~~~~~~~~~"); - Ok(()) - }) - .then(|res| { - if let Err(e) = res { - println!(" -> err={:?}", e); - } - - Ok(()) - }); - - tokio::spawn(Box::new(connection)); - Ok(()) - }) - .map_err(|e| eprintln!("accept error: {}", e)); - - tokio::run(server); + Ok(()) } + +async fn handle(socket: io::Result) -> Result<(), Box> { + let mut connection = server::handshake(socket?).await?; + println!("H2 connection bound"); + + while let Some(result) = connection.next().await { + let (request, mut respond) = result?; + println!("GOT request: {:?}", request); + let response = Response::builder().status(StatusCode::OK).body(()).unwrap(); + + let mut send = respond.send_response(response, false)?; + + println!(">>>> sending data"); + send.send_data(Bytes::from_static(b"hello world"), true)?; + } + + println!("~~~~~~~~~~~~~~~~~~~~~~~~~~~ H2 connection CLOSE !!!!!! ~~~~~~~~~~~"); + + Ok(()) +} \ No newline at end of file diff --git a/src/client.rs b/src/client.rs index 17375b32a..70b84fe2c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -64,72 +64,60 @@ //! //! # Example //! -//! ```rust +//! ```rust, no_run +//! #![feature(async_await)] +//! //! use h2::client; //! //! use futures::*; -//! use http::*; -//! +//! use http::{Request, Method}; +//! use std::error::Error; //! use tokio::net::TcpStream; //! -//! pub fn main() { +//! #[tokio::main] +//! pub async fn main() -> Result<(), Box> { //! let addr = "127.0.0.1:5928".parse().unwrap(); +//! +//! // Establish TCP connection to the server. +//! let tcp = TcpStream::connect(&addr).await?; +//! let (h2, connection) = client::handshake(tcp).await?; +//! tokio::spawn(async move { +//! connection.await.unwrap(); +//! }); //! -//! tokio::run( -//! // Establish TCP connection to the server. -//! TcpStream::connect(&addr) -//! .map_err(|_| { -//! panic!("failed to establish TCP connection") -//! }) -//! .and_then(|tcp| client::handshake(tcp)) -//! .and_then(|(h2, connection)| { -//! let connection = connection -//! .map_err(|_| panic!("HTTP/2.0 connection failed")); -//! -//! // Spawn a new task to drive the connection state -//! tokio::spawn(connection); -//! -//! // Wait until the `SendRequest` handle has available -//! // capacity. -//! h2.ready() -//! }) -//! .and_then(|mut h2| { -//! // Prepare the HTTP request to send to the server. -//! let request = Request::builder() +//! let mut h2 = h2.ready().await?; +//! // Prepare the HTTP request to send to the server. +//! let request = Request::builder() //! .method(Method::GET) //! .uri("https://www.example.com/") //! .body(()) //! .unwrap(); //! -//! // Send the request. The second tuple item allows the caller -//! // to stream a request body. -//! let (response, _) = h2.send_request(request, true).unwrap(); +//! // Send the request. The second tuple item allows the caller +//! // to stream a request body. +//! let (response, _) = h2.send_request(request, true).unwrap(); +//! +//! let (head, mut body) = response.await?.into_parts(); //! -//! response.and_then(|response| { -//! let (head, mut body) = response.into_parts(); +//! println!("Received response: {:?}", head); //! -//! println!("Received response: {:?}", head); +//! // The `release_capacity` handle allows the caller to manage +//! // flow control. +//! // +//! // Whenever data is received, the caller is responsible for +//! // releasing capacity back to the server once it has freed +//! // the data from memory. +//! let mut release_capacity = body.release_capacity().clone(); //! -//! // The `release_capacity` handle allows the caller to manage -//! // flow control. -//! // -//! // Whenever data is received, the caller is responsible for -//! // releasing capacity back to the server once it has freed -//! // the data from memory. -//! let mut release_capacity = body.release_capacity().clone(); +//! while let Some(chunk) = body.next().await { +//! let chunk = chunk?; +//! println!("RX: {:?}", chunk); //! -//! body.for_each(move |chunk| { -//! println!("RX: {:?}", chunk); +//! // Let the server send more data. +//! let _ = release_capacity.release_capacity(chunk.len()); +//! } //! -//! // Let the server send more data. -//! let _ = release_capacity.release_capacity(chunk.len()); -//! -//! Ok(()) -//! }) -//! }) -//! }) -//! .map_err(|e| panic!("failed to perform HTTP/2.0 request: {:?}", e)) -//! ) +//! Ok(()) //! } //! ``` //! @@ -151,42 +139,21 @@ //! [`Builder`]: struct.Builder.html //! [`Error`]: ../struct.Error.html -use crate::{SendStream, RecvStream, ReleaseCapacity, PingPong}; use crate::codec::{Codec, RecvError, SendError, UserError}; use crate::frame::{Headers, Pseudo, Reason, Settings, StreamId}; use crate::proto; +use crate::{PingPong, RecvStream, ReleaseCapacity, SendStream}; use bytes::{Bytes, IntoBuf}; -use futures::{Async, Future, Poll, Stream, try_ready}; -use http::{uri, HeaderMap, Request, Response, Method, Version}; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_io::io::WriteAll; - +use futures::{ready, FutureExt, Stream}; +use http::{uri, HeaderMap, Method, Request, Response, Version}; use std::fmt; -use std::marker::PhantomData; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; use std::usize; - -/// Performs the HTTP/2.0 connection handshake. -/// -/// This type implements `Future`, yielding a `(SendRequest, Connection)` -/// instance once the handshake has completed. -/// -/// The handshake is completed once both the connection preface and the initial -/// settings frame is sent by the client. -/// -/// The handshake future does not wait for the initial settings frame from the -/// server. -/// -/// See [module] level documentation for more details. -/// -/// [module]: index.html -#[must_use = "futures do nothing unless polled"] -pub struct Handshake { - builder: Builder, - inner: WriteAll, - _marker: PhantomData, -} +use tokio_io::{AsyncRead, AsyncWrite, AsyncWriteExt}; /// Initializes new HTTP/2.0 streams on a connection by sending a request. /// @@ -246,31 +213,21 @@ pub struct ReadySendRequest { /// # Examples /// /// ``` -/// # use futures::{Future, Stream}; -/// # use futures::future::Executor; +/// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client; /// # use h2::client::*; /// # -/// # fn doc(my_io: T, my_executor: E) -/// # where T: AsyncRead + AsyncWrite + 'static, -/// # E: Executor>>, +/// # async fn doc(my_io: T) -> Result<(), h2::Error> +/// # where T: AsyncRead + AsyncWrite + Send + Unpin + 'static, /// # { -/// client::handshake(my_io) -/// .and_then(|(send_request, connection)| { -/// // Submit the connection handle to an executor. -/// my_executor.execute( -/// # Box::new( -/// connection.map_err(|_| panic!("connection failed")) -/// # ) -/// ).unwrap(); +/// let (send_request, connection) = client::handshake(my_io).await?; +/// // Submit the connection handle to an executor. +/// tokio::spawn(async { connection.await.expect("connection failed"); }); /// -/// // Now, use `send_request` to initialize HTTP/2.0 streams. -/// // ... -/// # drop(send_request); -/// # Ok(()) -/// }) -/// # .wait().unwrap(); +/// // Now, use `send_request` to initialize HTTP/2.0 streams. +/// // ... +/// # Ok(()) /// # } /// # /// # pub fn main() {} @@ -335,11 +292,13 @@ pub struct PushPromises { /// # Examples /// /// ``` +/// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; +/// # use bytes::Bytes; /// # -/// # fn doc(my_io: T) -/// # -> Handshake +/// # async fn doc(my_io: T) +/// -> Result<((SendRequest, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. @@ -347,7 +306,7 @@ pub struct PushPromises { /// .initial_window_size(1_000_000) /// .max_concurrent_streams(1000) /// .handshake(my_io); -/// # client_fut +/// # client_fut.await /// # } /// # /// # pub fn main() {} @@ -384,23 +343,23 @@ pub(crate) struct Peer; impl SendRequest where - B: IntoBuf, - B::Buf: 'static, + B: IntoBuf + Unpin, + B::Buf: Unpin + 'static, { /// Returns `Ready` when the connection can initialize a new HTTP/2.0 /// stream. /// /// This function must return `Ready` before `send_request` is called. When - /// `NotReady` is returned, the task will be notified once the readiness + /// `Poll::Pending` is returned, the task will be notified once the readiness /// state changes. /// /// See [module] level docs for more details. /// /// [module]: index.html - pub fn poll_ready(&mut self) -> Poll<(), crate::Error> { - try_ready!(self.inner.poll_pending_open(self.pending.as_ref())); + pub fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + ready!(self.inner.poll_pending_open(cx, self.pending.as_ref()))?; self.pending = None; - Ok(().into()) + Poll::Ready(Ok(())) } /// Consumes `self`, returning a future that returns `self` back once it is @@ -415,19 +374,15 @@ where /// # Examples /// /// ```rust - /// # use futures::*; + /// # #![feature(async_await)] /// # use h2::client::*; /// # use http::*; - /// # fn doc(send_request: SendRequest<&'static [u8]>) + /// # async fn doc(send_request: SendRequest<&'static [u8]>) /// # { /// // First, wait until the `send_request` handle is ready to send a new /// // request - /// send_request.ready() - /// .and_then(|mut send_request| { - /// // Use `send_request` here. - /// # Ok(()) - /// }) - /// # .wait().unwrap(); + /// let mut send_request = send_request.ready().await.unwrap(); + /// // Use `send_request` here. /// # } /// # pub fn main() {} /// ``` @@ -479,32 +434,24 @@ where /// Sending a request with no body /// /// ```rust - /// # use futures::*; + /// # #![feature(async_await)] /// # use h2::client::*; /// # use http::*; - /// # fn doc(send_request: SendRequest<&'static [u8]>) + /// # async fn doc(send_request: SendRequest<&'static [u8]>) /// # { /// // First, wait until the `send_request` handle is ready to send a new /// // request - /// send_request.ready() - /// .and_then(|mut send_request| { - /// // Prepare the HTTP request to send to the server. - /// let request = Request::get("https://www.example.com/") - /// .body(()) - /// .unwrap(); - /// - /// // Send the request to the server. Since we are not sending a - /// // body or trailers, we can drop the `SendStream` instance. - /// let (response, _) = send_request - /// .send_request(request, true).unwrap(); - /// - /// response - /// }) - /// .and_then(|response| { - /// // Process the response - /// # Ok(()) - /// }) - /// # .wait().unwrap(); + /// let mut send_request = send_request.ready().await.unwrap(); + /// // Prepare the HTTP request to send to the server. + /// let request = Request::get("https://www.example.com/") + /// .body(()) + /// .unwrap(); + /// + /// // Send the request to the server. Since we are not sending a + /// // body or trailers, we can drop the `SendStream` instance. + /// let (response, _) = send_request.send_request(request, true).unwrap(); + /// let response = response.await.unwrap(); + /// // Process the response /// # } /// # pub fn main() {} /// ``` @@ -512,48 +459,43 @@ where /// Sending a request with a body and trailers /// /// ```rust - /// # use futures::*; + /// # #![feature(async_await)] /// # use h2::client::*; /// # use http::*; - /// # fn doc(send_request: SendRequest<&'static [u8]>) + /// # async fn doc(send_request: SendRequest<&'static [u8]>) /// # { /// // First, wait until the `send_request` handle is ready to send a new /// // request - /// send_request.ready() - /// .and_then(|mut send_request| { - /// // Prepare the HTTP request to send to the server. - /// let request = Request::get("https://www.example.com/") - /// .body(()) - /// .unwrap(); - /// - /// // Send the request to the server. If we are not sending a - /// // body or trailers, we can drop the `SendStream` instance. - /// let (response, mut send_stream) = send_request - /// .send_request(request, false).unwrap(); - /// - /// // At this point, one option would be to wait for send capacity. - /// // Doing so would allow us to not hold data in memory that - /// // cannot be sent. However, this is not a requirement, so this - /// // example will skip that step. See `SendStream` documentation - /// // for more details. - /// send_stream.send_data(b"hello", false).unwrap(); - /// send_stream.send_data(b"world", false).unwrap(); - /// - /// // Send the trailers. - /// let mut trailers = HeaderMap::new(); - /// trailers.insert( - /// header::HeaderName::from_bytes(b"my-trailer").unwrap(), - /// header::HeaderValue::from_bytes(b"hello").unwrap()); - /// - /// send_stream.send_trailers(trailers).unwrap(); - /// - /// response - /// }) - /// .and_then(|response| { - /// // Process the response - /// # Ok(()) - /// }) - /// # .wait().unwrap(); + /// let mut send_request = send_request.ready().await.unwrap(); + /// + /// // Prepare the HTTP request to send to the server. + /// let request = Request::get("https://www.example.com/") + /// .body(()) + /// .unwrap(); + /// + /// // Send the request to the server. If we are not sending a + /// // body or trailers, we can drop the `SendStream` instance. + /// let (response, mut send_stream) = send_request + /// .send_request(request, false).unwrap(); + /// + /// // At this point, one option would be to wait for send capacity. + /// // Doing so would allow us to not hold data in memory that + /// // cannot be sent. However, this is not a requirement, so this + /// // example will skip that step. See `SendStream` documentation + /// // for more details. + /// send_stream.send_data(b"hello", false).unwrap(); + /// send_stream.send_data(b"world", false).unwrap(); + /// + /// // Send the trailers. + /// let mut trailers = HeaderMap::new(); + /// trailers.insert( + /// header::HeaderName::from_bytes(b"my-trailer").unwrap(), + /// header::HeaderValue::from_bytes(b"hello").unwrap()); + /// + /// send_stream.send_trailers(trailers).unwrap(); + /// + /// let response = response.await.unwrap(); + /// // Process the response /// # } /// # pub fn main() {} /// ``` @@ -634,21 +576,21 @@ where // ===== impl ReadySendRequest ===== impl Future for ReadySendRequest -where B: IntoBuf, - B::Buf: 'static, +where + B: IntoBuf + Unpin, + B::Buf: Unpin + 'static, { - type Item = SendRequest; - type Error = crate::Error; + type Output = Result, crate::Error>; - fn poll(&mut self) -> Poll { - match self.inner { - Some(ref mut send_request) => { - let _ = try_ready!(send_request.poll_ready()); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match &mut self.inner { + Some(send_request) => { + let _ = ready!(send_request.poll_ready(cx))?; } None => panic!("called `poll` after future completed"), } - Ok(self.inner.take().unwrap().into()) + Poll::Ready(Ok(self.inner.take().unwrap())) } } @@ -663,11 +605,13 @@ impl Builder { /// # Examples /// /// ``` + /// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; + /// # use bytes::Bytes; /// # - /// # fn doc(my_io: T) - /// # -> Handshake + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. @@ -675,7 +619,7 @@ impl Builder { /// .initial_window_size(1_000_000) /// .max_concurrent_streams(1000) /// .handshake(my_io); - /// # client_fut + /// # client_fut.await /// # } /// # /// # pub fn main() {} @@ -704,18 +648,20 @@ impl Builder { /// # Examples /// /// ``` + /// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; + /// # use bytes::Bytes; /// # - /// # fn doc(my_io: T) - /// # -> Handshake + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. /// let client_fut = Builder::new() /// .initial_window_size(1_000_000) /// .handshake(my_io); - /// # client_fut + /// # client_fut.await /// # } /// # /// # pub fn main() {} @@ -738,18 +684,20 @@ impl Builder { /// # Examples /// /// ``` + /// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; + /// # use bytes::Bytes; /// # - /// # fn doc(my_io: T) - /// # -> Handshake + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. /// let client_fut = Builder::new() /// .initial_connection_window_size(1_000_000) /// .handshake(my_io); - /// # client_fut + /// # client_fut.await /// # } /// # /// # pub fn main() {} @@ -771,18 +719,20 @@ impl Builder { /// # Examples /// /// ``` + /// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; + /// # use bytes::Bytes; /// # - /// # fn doc(my_io: T) - /// # -> Handshake + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. /// let client_fut = Builder::new() /// .max_frame_size(1_000_000) /// .handshake(my_io); - /// # client_fut + /// # client_fut.await /// # } /// # /// # pub fn main() {} @@ -810,18 +760,20 @@ impl Builder { /// # Examples /// /// ``` + /// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; + /// # use bytes::Bytes; /// # - /// # fn doc(my_io: T) - /// # -> Handshake + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. /// let client_fut = Builder::new() /// .max_header_list_size(16 * 1024) /// .handshake(my_io); - /// # client_fut + /// # client_fut.await /// # } /// # /// # pub fn main() {} @@ -858,18 +810,20 @@ impl Builder { /// # Examples /// /// ``` + /// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; + /// # use bytes::Bytes; /// # - /// # fn doc(my_io: T) - /// # -> Handshake + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. /// let client_fut = Builder::new() /// .max_concurrent_streams(1000) /// .handshake(my_io); - /// # client_fut + /// # client_fut.await /// # } /// # /// # pub fn main() {} @@ -898,18 +852,20 @@ impl Builder { /// # Examples /// /// ``` + /// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; + /// # use bytes::Bytes; /// # - /// # fn doc(my_io: T) - /// # -> Handshake + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. /// let client_fut = Builder::new() /// .initial_max_send_streams(1000) /// .handshake(my_io); - /// # client_fut + /// # client_fut.await /// # } /// # /// # pub fn main() {} @@ -942,18 +898,20 @@ impl Builder { /// # Examples /// /// ``` + /// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; + /// # use bytes::Bytes; /// # - /// # fn doc(my_io: T) - /// # -> Handshake + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. /// let client_fut = Builder::new() /// .max_concurrent_reset_streams(1000) /// .handshake(my_io); - /// # client_fut + /// # client_fut.await /// # } /// # /// # pub fn main() {} @@ -986,19 +944,21 @@ impl Builder { /// # Examples /// /// ``` + /// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; /// # use std::time::Duration; + /// # use bytes::Bytes; /// # - /// # fn doc(my_io: T) - /// # -> Handshake + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. /// let client_fut = Builder::new() /// .reset_stream_duration(Duration::from_secs(10)) /// .handshake(my_io); - /// # client_fut + /// # client_fut.await /// # } /// # /// # pub fn main() {} @@ -1023,19 +983,21 @@ impl Builder { /// # Examples /// /// ``` + /// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; /// # use std::time::Duration; + /// # use bytes::Bytes; /// # - /// # fn doc(my_io: T) - /// # -> Handshake + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. /// let client_fut = Builder::new() /// .enable_push(false) /// .handshake(my_io); - /// # client_fut + /// # client_fut.await /// # } /// # /// # pub fn main() {} @@ -1059,7 +1021,11 @@ impl Builder { /// Creates a new configured HTTP/2.0 client backed by `io`. /// /// It is expected that `io` already be in an appropriate state to commence - /// the [HTTP/2.0 handshake]. See [Handshake] for more details. + /// the [HTTP/2.0 handshake]. The handshake is completed once both the connection + /// preface and the initial settings frame is sent by the client. + /// + /// The handshake future does not wait for the initial settings frame from the + /// server. /// /// Returns a future which resolves to the [`Connection`] / [`SendRequest`] /// tuple once the HTTP/2.0 handshake has been completed. @@ -1068,7 +1034,6 @@ impl Builder { /// type. See [Outbound data type] for more details. /// /// [HTTP/2.0 handshake]: http://httpwg.org/specs/rfc7540.html#ConnectionHeader - /// [Handshake]: ../index.html#handshake /// [`Connection`]: struct.Connection.html /// [`SendRequest`]: struct.SendRequest.html /// [Outbound data type]: ../index.html#outbound-data-type. @@ -1078,17 +1043,19 @@ impl Builder { /// Basic usage: /// /// ``` + /// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; + /// # use bytes::Bytes; /// # - /// # fn doc(my_io: T) - /// # -> Handshake + /// # async fn doc(my_io: T) + /// -> Result<((SendRequest, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. /// let client_fut = Builder::new() /// .handshake(my_io); - /// # client_fut + /// # client_fut.await /// # } /// # /// # pub fn main() {} @@ -1098,26 +1065,30 @@ impl Builder { /// type will be `&'static [u8]`. /// /// ``` + /// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client::*; /// # - /// # fn doc(my_io: T) - /// # -> Handshake + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest<&'static [u8]>, Connection)), h2::Error> /// # { /// // `client_fut` is a future representing the completion of the HTTP/2.0 /// // handshake. - /// let client_fut: Handshake<_, &'static [u8]> = Builder::new() - /// .handshake(my_io); - /// # client_fut + /// let client_fut = Builder::new() + /// .handshake::<_, &'static [u8]>(my_io); + /// # client_fut.await /// # } /// # /// # pub fn main() {} /// ``` - pub fn handshake(&self, io: T) -> Handshake + pub fn handshake( + &self, + io: T, + ) -> impl Future, Connection), crate::Error>> where - T: AsyncRead + AsyncWrite, - B: IntoBuf, - B::Buf: 'static, + T: AsyncRead + AsyncWrite + Unpin, + B: IntoBuf + Unpin, + B::Buf: Unpin + 'static, { Connection::handshake2(io, self.clone()) } @@ -1149,51 +1120,86 @@ impl Default for Builder { /// # Examples /// /// ``` -/// # use futures::*; +/// # #![feature(async_await)] /// # use tokio_io::*; /// # use h2::client; /// # use h2::client::*; /// # -/// # fn doc(my_io: T) +/// # async fn doc(my_io: T) -> Result<(), h2::Error> /// # { -/// client::handshake(my_io) -/// .and_then(|(send_request, connection)| { -/// // The HTTP/2.0 handshake has completed, now start polling -/// // `connection` and use `send_request` to send requests to the -/// // server. -/// # Ok(()) -/// }) -/// # .wait().unwrap(); +/// let (send_request, connection) = client::handshake(my_io).await?; +/// // The HTTP/2.0 handshake has completed, now start polling +/// // `connection` and use `send_request` to send requests to the +/// // server. +/// # Ok(()) /// # } /// # /// # pub fn main() {} /// ``` -pub fn handshake(io: T) -> Handshake -where T: AsyncRead + AsyncWrite, +pub async fn handshake(io: T) -> Result<(SendRequest, Connection), crate::Error> +where + T: AsyncRead + AsyncWrite + Unpin, { - Builder::new().handshake(io) + let builder = Builder::new(); + builder.handshake(io).await } // ===== impl Connection ===== impl Connection where - T: AsyncRead + AsyncWrite, - B: IntoBuf, + T: AsyncRead + AsyncWrite + Unpin, + B: IntoBuf + Unpin, + B::Buf: Unpin, { - fn handshake2(io: T, builder: Builder) -> Handshake { - use tokio_io::io; - + async fn handshake2( + mut io: T, + builder: Builder, + ) -> Result<(SendRequest, Connection), crate::Error> { log::debug!("binding client connection"); let msg: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; - let handshake = io::write_all(io, msg); + io.write_all(msg).await?; + + log::debug!("client connection bound"); + + // Create the codec + let mut codec = Codec::new(io); + + if let Some(max) = builder.settings.max_frame_size() { + codec.set_max_recv_frame_size(max as usize); + } + + if let Some(max) = builder.settings.max_header_list_size() { + codec.set_max_recv_header_list_size(max as usize); + } + + // Send initial settings frame + codec + .buffer(builder.settings.clone().into()) + .expect("invalid SETTINGS frame"); - Handshake { - builder, - inner: handshake, - _marker: PhantomData, + let inner = proto::Connection::new( + codec, + proto::Config { + next_stream_id: builder.stream_id, + initial_max_send_streams: builder.initial_max_send_streams, + reset_stream_duration: builder.reset_stream_duration, + reset_stream_max: builder.reset_stream_max, + settings: builder.settings.clone(), + }, + ); + let send_request = SendRequest { + inner: inner.streams().clone(), + pending: None, + }; + + let mut connection = Connection { inner }; + if let Some(sz) = builder.initial_target_connection_window_size { + connection.set_target_window_size(sz); } + + Ok((send_request, connection)) } /// Sets the target window size for the whole connection. @@ -1224,23 +1230,21 @@ where /// /// This may only be called once. Calling multiple times will return `None`. pub fn ping_pong(&mut self) -> Option { - self.inner - .take_user_pings() - .map(PingPong::new) + self.inner.take_user_pings().map(PingPong::new) } } impl Future for Connection where - T: AsyncRead + AsyncWrite, - B: IntoBuf, + T: AsyncRead + AsyncWrite + Unpin, + B: IntoBuf + Unpin, + B::Buf: Unpin, { - type Item = (); - type Error = crate::Error; + type Output = Result<(), crate::Error>; - fn poll(&mut self) -> Poll<(), crate::Error> { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.inner.maybe_close_connection_if_no_streams(); - self.inner.poll().map_err(Into::into) + self.inner.poll(cx).map_err(Into::into) } } @@ -1256,85 +1260,16 @@ where } } -// ===== impl Handshake ===== - -impl Future for Handshake -where - T: AsyncRead + AsyncWrite, - B: IntoBuf, - B::Buf: 'static, -{ - type Item = (SendRequest, Connection); - type Error = crate::Error; - - fn poll(&mut self) -> Poll { - let res = self.inner.poll() - .map_err(crate::Error::from); - - let (io, _) = try_ready!(res); - - log::debug!("client connection bound"); - - // Create the codec - let mut codec = Codec::new(io); - - if let Some(max) = self.builder.settings.max_frame_size() { - codec.set_max_recv_frame_size(max as usize); - } - - if let Some(max) = self.builder.settings.max_header_list_size() { - codec.set_max_recv_header_list_size(max as usize); - } - - // Send initial settings frame - codec - .buffer(self.builder.settings.clone().into()) - .expect("invalid SETTINGS frame"); - - let inner = proto::Connection::new(codec, proto::Config { - next_stream_id: self.builder.stream_id, - initial_max_send_streams: self.builder.initial_max_send_streams, - reset_stream_duration: self.builder.reset_stream_duration, - reset_stream_max: self.builder.reset_stream_max, - settings: self.builder.settings.clone(), - }); - let send_request = SendRequest { - inner: inner.streams().clone(), - pending: None, - }; - - let mut connection = Connection { inner }; - if let Some(sz) = self.builder.initial_target_connection_window_size { - connection.set_target_window_size(sz); - } - - Ok(Async::Ready((send_request, connection))) - } -} - -impl fmt::Debug for Handshake -where - T: AsyncRead + AsyncWrite, - T: fmt::Debug, - B: fmt::Debug + IntoBuf, - B::Buf: fmt::Debug + IntoBuf, -{ - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "client::Handshake") - } -} - // ===== impl ResponseFuture ===== impl Future for ResponseFuture { - type Item = Response; - type Error = crate::Error; + type Output = Result, crate::Error>; - fn poll(&mut self) -> Poll { - let (parts, _) = try_ready!(self.inner.poll_response()).into_parts(); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let (parts, _) = ready!(self.inner.poll_response(cx))?.into_parts(); let body = RecvStream::new(ReleaseCapacity::new(self.inner.clone())); - Ok(Response::from_parts(parts, body).into()) + Poll::Ready(Ok(Response::from_parts(parts, body).into())) } } @@ -1358,27 +1293,31 @@ impl ResponseFuture { panic!("Reference to push promises stream taken!"); } self.push_promise_consumed = true; - PushPromises { inner: self.inner.clone() } + PushPromises { + inner: self.inner.clone(), + } } } // ===== impl PushPromises ===== impl Stream for PushPromises { - type Item = PushPromise; - type Error = crate::Error; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - match try_ready!(self.inner.poll_pushed()) { - Some((request, response)) => { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.inner.poll_pushed(cx) { + Poll::Ready(Some(Ok((request, response)))) => { let response = PushedResponseFuture { inner: ResponseFuture { - inner: response, push_promise_consumed: false - } + inner: response, + push_promise_consumed: false, + }, }; - Ok(Async::Ready(Some(PushPromise{request, response}))) + Poll::Ready(Some(Ok(PushPromise { request, response }))) } - None => Ok(Async::Ready(None)), + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, } } } @@ -1406,11 +1345,10 @@ impl PushPromise { // ===== impl PushedResponseFuture ===== impl Future for PushedResponseFuture { - type Item = Response; - type Error = crate::Error; + type Output = Result, crate::Error>; - fn poll(&mut self) -> Poll { - self.inner.poll() + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.inner.poll_unpin(cx) } } @@ -1431,8 +1369,8 @@ impl Peer { pub fn convert_send_message( id: StreamId, request: Request<()>, - end_of_stream: bool) -> Result - { + end_of_stream: bool, + ) -> Result { use http::request::Parts; let ( @@ -1503,7 +1441,9 @@ impl proto::Peer for Peer { } fn convert_poll_message( - pseudo: Pseudo, fields: HeaderMap, stream_id: StreamId + pseudo: Pseudo, + fields: HeaderMap, + stream_id: StreamId, ) -> Result { let mut b = Response::builder(); @@ -1522,7 +1462,7 @@ impl proto::Peer for Peer { id: stream_id, reason: Reason::PROTOCOL_ERROR, }); - }, + } }; *response.headers_mut() = fields; diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs index 4a6bc4377..59d08a81d 100644 --- a/src/codec/framed_read.rs +++ b/src/codec/framed_read.rs @@ -1,24 +1,29 @@ use crate::codec::RecvError; use crate::frame::{self, Frame, Kind, Reason}; -use crate::frame::{DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE}; +use crate::frame::{ + DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE, +}; use crate::hpack; -use futures::*; +use futures::{ready, Stream}; use bytes::BytesMut; use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio_codec::{LengthDelimitedCodec, LengthDelimitedCodecError}; +use tokio_codec::FramedRead as InnerFramedRead; use tokio_io::AsyncRead; -use tokio_io::codec::length_delimited; // 16 MB "sane default" taken from golang http2 const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 << 20; #[derive(Debug)] pub struct FramedRead { - inner: length_delimited::FramedRead, + inner: InnerFramedRead, // hpack decoder state hpack: hpack::Decoder, @@ -45,7 +50,7 @@ enum Continuable { } impl FramedRead { - pub fn new(inner: length_delimited::FramedRead) -> FramedRead { + pub fn new(inner: InnerFramedRead) -> FramedRead { FramedRead { inner: inner, hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE), @@ -138,24 +143,27 @@ impl FramedRead { res.map_err(|e| { proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e); Connection(Reason::PROTOCOL_ERROR) - })?.into() - }, + })? + .into() + } Kind::Ping => { let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); res.map_err(|e| { proto_err!(conn: "failed to load PING frame; err={:?}", e); Connection(Reason::PROTOCOL_ERROR) - })?.into() - }, + })? + .into() + } Kind::WindowUpdate => { let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); res.map_err(|e| { proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e); Connection(Reason::PROTOCOL_ERROR) - })?.into() - }, + })? + .into() + } Kind::Data => { let _ = bytes.split_to(frame::HEADER_LEN); let res = frame::Data::load(head, bytes.freeze()); @@ -164,28 +172,27 @@ impl FramedRead { res.map_err(|e| { proto_err!(conn: "failed to load DATA frame; err={:?}", e); Connection(Reason::PROTOCOL_ERROR) - })?.into() - }, - Kind::Headers => { - header_block!(Headers, head, bytes) - }, + })? + .into() + } + Kind::Headers => header_block!(Headers, head, bytes), Kind::Reset => { let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); res.map_err(|e| { proto_err!(conn: "failed to load RESET frame; err={:?}", e); Connection(Reason::PROTOCOL_ERROR) - })?.into() - }, + })? + .into() + } Kind::GoAway => { let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]); res.map_err(|e| { proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e); Connection(Reason::PROTOCOL_ERROR) - })?.into() - }, - Kind::PushPromise => { - header_block!(PushPromise, head, bytes) - }, + })? + .into() + } + Kind::PushPromise => header_block!(PushPromise, head, bytes), Kind::Priority => { if head.stream_id() == 0 { // Invalid stream identifier @@ -205,13 +212,13 @@ impl FramedRead { id, reason: Reason::PROTOCOL_ERROR, }); - }, + } Err(e) => { proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e); return Err(Connection(Reason::PROTOCOL_ERROR)); } } - }, + } Kind::Continuation => { let is_end_headers = (head.flag() & 0x4) == 0x4; @@ -229,8 +236,6 @@ impl FramedRead { return Err(Connection(Reason::PROTOCOL_ERROR)); } - - // Extend the buf if partial.buf.is_empty() { partial.buf = bytes.split_off(frame::HEADER_LEN); @@ -257,9 +262,14 @@ impl FramedRead { partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); } - match partial.frame.load_hpack(&mut partial.buf, self.max_header_list_size, &mut self.hpack) { - Ok(_) => {}, - Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}, + match partial.frame.load_hpack( + &mut partial.buf, + self.max_header_list_size, + &mut self.hpack, + ) { + Ok(_) => {} + Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) + if !is_end_headers => {} Err(frame::Error::MalformedMessage) => { let id = head.stream_id(); proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id); @@ -267,11 +277,11 @@ impl FramedRead { id, reason: Reason::PROTOCOL_ERROR, }); - }, + } Err(e) => { proto_err!(conn: "failed HPACK decoding; err={:?}", e); return Err(Connection(Reason::PROTOCOL_ERROR)); - }, + } } if is_end_headers { @@ -280,11 +290,11 @@ impl FramedRead { self.partial = Some(partial); return Ok(None); } - }, + } Kind::Unknown => { // Unknown frames are ignored return Ok(None); - }, + } }; Ok(Some(frame)) @@ -302,7 +312,7 @@ impl FramedRead { #[cfg(feature = "unstable")] #[inline] pub fn max_frame_size(&self) -> usize { - self.inner.max_frame_length() + self.inner.decoder().max_frame_length() } /// Updates the max frame size setting. @@ -311,7 +321,7 @@ impl FramedRead { #[inline] pub fn set_max_frame_size(&mut self, val: usize) { assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize); - self.inner.set_max_frame_length(val) + self.inner.decoder_mut().set_max_frame_length(val) } /// Update the max header list size setting. @@ -323,34 +333,32 @@ impl FramedRead { impl Stream for FramedRead where - T: AsyncRead, + T: AsyncRead + Unpin, { - type Item = Frame; - type Error = RecvError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { log::trace!("poll"); - let bytes = match try_ready!(self.inner.poll().map_err(map_err)) { - Some(bytes) => bytes, - None => return Ok(Async::Ready(None)), + let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(bytes)) => bytes, + Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))), + None => return Poll::Ready(None), }; log::trace!("poll; bytes={}B", bytes.len()); if let Some(frame) = self.decode_frame(bytes)? { log::debug!("received; frame={:?}", frame); - return Ok(Async::Ready(Some(frame))); + return Poll::Ready(Some(Ok(frame))); } } } } fn map_err(err: io::Error) -> RecvError { - use tokio_io::codec::length_delimited::FrameTooBig; - if let io::ErrorKind::InvalidData = err.kind() { if let Some(custom) = err.get_ref() { - if custom.is::() { + if custom.is::() { return RecvError::Connection(Reason::FRAME_SIZE_ERROR); } } diff --git a/src/codec/framed_write.rs b/src/codec/framed_write.rs index fa6ac18ee..cfcbdd46f 100644 --- a/src/codec/framed_write.rs +++ b/src/codec/framed_write.rs @@ -4,8 +4,10 @@ use crate::frame::{self, Frame, FrameSize}; use crate::hpack; use bytes::{Buf, BufMut, BytesMut}; -use futures::*; -use tokio_io::{AsyncRead, AsyncWrite, try_nb}; +use futures::ready; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio_io::{AsyncRead, AsyncWrite}; use std::io::{self, Cursor}; @@ -55,12 +57,12 @@ const CHAIN_THRESHOLD: usize = 256; // TODO: Make generic impl FramedWrite where - T: AsyncWrite, + T: AsyncWrite + Unpin, B: Buf, { pub fn new(inner: T) -> FramedWrite { FramedWrite { - inner: inner, + inner, hpack: hpack::Encoder::default(), buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)), next: None, @@ -73,17 +75,17 @@ where /// /// Calling this function may result in the current contents of the buffer /// to be flushed to `T`. - pub fn poll_ready(&mut self) -> Poll<(), io::Error> { + pub fn poll_ready(&mut self, cx: &mut Context) -> Poll> { if !self.has_capacity() { // Try flushing - self.flush()?; + ready!(self.flush(cx))?; if !self.has_capacity() { - return Ok(Async::NotReady); + return Poll::Pending; } } - Ok(Async::Ready(())) + Poll::Ready(Ok(())) } /// Buffer a frame. @@ -123,33 +125,33 @@ where // Save off the last frame... self.last_data_frame = Some(v); } - }, + } Frame::Headers(v) => { if let Some(continuation) = v.encode(&mut self.hpack, self.buf.get_mut()) { self.next = Some(Next::Continuation(continuation)); } - }, + } Frame::PushPromise(v) => { if let Some(continuation) = v.encode(&mut self.hpack, self.buf.get_mut()) { self.next = Some(Next::Continuation(continuation)); } - }, + } Frame::Settings(v) => { v.encode(self.buf.get_mut()); log::trace!("encoded settings; rem={:?}", self.buf.remaining()); - }, + } Frame::GoAway(v) => { v.encode(self.buf.get_mut()); log::trace!("encoded go_away; rem={:?}", self.buf.remaining()); - }, + } Frame::Ping(v) => { v.encode(self.buf.get_mut()); log::trace!("encoded ping; rem={:?}", self.buf.remaining()); - }, + } Frame::WindowUpdate(v) => { v.encode(self.buf.get_mut()); log::trace!("encoded window_update; rem={:?}", self.buf.remaining()); - }, + } Frame::Priority(_) => { /* @@ -157,18 +159,18 @@ where log::trace!("encoded priority; rem={:?}", self.buf.remaining()); */ unimplemented!(); - }, + } Frame::Reset(v) => { v.encode(self.buf.get_mut()); log::trace!("encoded reset; rem={:?}", self.buf.remaining()); - }, + } } Ok(()) } /// Flush buffered data to the wire - pub fn flush(&mut self) -> Poll<(), io::Error> { + pub fn flush(&mut self, cx: &mut Context) -> Poll> { log::trace!("flush"); loop { @@ -177,12 +179,12 @@ where Some(Next::Data(ref mut frame)) => { log::trace!(" -> queued data frame"); let mut buf = Buf::by_ref(&mut self.buf).chain(frame.payload_mut()); - try_ready!(self.inner.write_buf(&mut buf)); - }, + ready!(Pin::new(&mut self.inner).poll_write_buf(cx, &mut buf))?; + } _ => { log::trace!(" -> not a queued data frame"); - try_ready!(self.inner.write_buf(&mut self.buf)); - }, + ready!(Pin::new(&mut self.inner).poll_write_buf(cx, &mut self.buf))?; + } } } @@ -196,11 +198,10 @@ where self.last_data_frame = Some(frame); debug_assert!(self.is_empty()); break; - }, + } Some(Next::Continuation(frame)) => { // Buffer the continuation frame, then try to write again if let Some(continuation) = frame.encode(&mut self.hpack, self.buf.get_mut()) { - // We previously had a CONTINUATION, and after encoding // it, we got *another* one? Let's just double check // that at least some progress is being made... @@ -213,7 +214,7 @@ where self.next = Some(Next::Continuation(continuation)); } - }, + } None => { break; } @@ -222,15 +223,15 @@ where log::trace!("flushing buffer"); // Flush the upstream - try_nb!(self.inner.flush()); + ready!(Pin::new(&mut self.inner).poll_flush(cx))?; - Ok(Async::Ready(())) + Poll::Ready(Ok(())) } /// Close the codec - pub fn shutdown(&mut self) -> Poll<(), io::Error> { - try_ready!(self.flush()); - self.inner.shutdown().map_err(Into::into) + pub fn shutdown(&mut self, cx: &mut Context) -> Poll> { + ready!(self.flush(cx))?; + Pin::new(&mut self.inner).poll_shutdown(cx) } fn has_capacity(&self) -> bool { @@ -267,22 +268,25 @@ impl FramedWrite { } } -impl io::Read for FramedWrite { - fn read(&mut self, dst: &mut [u8]) -> io::Result { - self.inner.read(dst) +impl AsyncRead for FramedWrite { + unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { + self.inner.prepare_uninitialized_buffer(buf) } -} -impl AsyncRead for FramedWrite { - fn read_buf(&mut self, buf: &mut B2) -> Poll - where - Self: Sized, - { - self.inner.read_buf(buf) + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut self.inner).poll_read(cx, buf) } - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.inner.prepare_uninitialized_buffer(buf) + fn poll_read_buf( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut Buf, + ) -> Poll> { + Pin::new(&mut self.inner).poll_read_buf(cx, buf) } } diff --git a/src/codec/mod.rs b/src/codec/mod.rs index 3ae00eb03..322174cf1 100644 --- a/src/codec/mod.rs +++ b/src/codec/mod.rs @@ -14,10 +14,11 @@ use crate::frame::{self, Data, Frame}; use futures::*; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_io::codec::length_delimited; - use bytes::Buf; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio_codec::length_delimited; +use tokio_io::{AsyncRead, AsyncWrite}; use std::io; @@ -28,8 +29,8 @@ pub struct Codec { impl Codec where - T: AsyncRead + AsyncWrite, - B: Buf, + T: AsyncRead + AsyncWrite + Unpin, + B: Buf + Unpin, { /// Returns a new `Codec` with the default max frame size #[inline] @@ -55,9 +56,7 @@ where // Use FramedRead's method since it checks the value is within range. inner.set_max_frame_size(max_frame_size); - Codec { - inner, - } + Codec { inner } } } @@ -121,12 +120,12 @@ impl Codec { impl Codec where - T: AsyncWrite, - B: Buf, + T: AsyncWrite + Unpin, + B: Buf + Unpin, { /// Returns `Ready` when the codec can buffer a frame - pub fn poll_ready(&mut self) -> Poll<(), io::Error> { - self.framed_write().poll_ready() + pub fn poll_ready(&mut self, cx: &mut Context) -> Poll> { + self.framed_write().poll_ready(cx) } /// Buffer a frame. @@ -140,60 +139,59 @@ where } /// Flush buffered data to the wire - pub fn flush(&mut self) -> Poll<(), io::Error> { - self.framed_write().flush() + pub fn flush(&mut self, cx: &mut Context) -> Poll> { + self.framed_write().flush(cx) } /// Shutdown the send half - pub fn shutdown(&mut self) -> Poll<(), io::Error> { - self.framed_write().shutdown() + pub fn shutdown(&mut self, cx: &mut Context) -> Poll> { + self.framed_write().shutdown(cx) } } impl Stream for Codec where - T: AsyncRead, + T: AsyncRead + Unpin, + B: Unpin, { - type Item = Frame; - type Error = RecvError; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - self.inner.poll() + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_next(cx) } } -impl Sink for Codec +impl Sink> for Codec where - T: AsyncWrite, - B: Buf, + T: AsyncWrite + Unpin, + B: Buf + Unpin, { - type SinkItem = Frame; - type SinkError = SendError; - - fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - if !self.poll_ready()?.is_ready() { - return Ok(AsyncSink::NotReady(item)); - } + type Error = SendError; - self.buffer(item)?; - Ok(AsyncSink::Ready) + fn start_send(mut self: Pin<&mut Self>, item: Frame) -> Result<(), Self::Error> { + Codec::buffer(&mut self, item)?; + Ok(()) + } + /// Returns `Ready` when the codec can buffer a frame + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.framed_write().poll_ready(cx).map_err(Into::into) } - fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.flush()?; - Ok(Async::Ready(())) + /// Flush buffered data to the wire + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.framed_write().flush(cx).map_err(Into::into) } - fn close(&mut self) -> Poll<(), Self::SinkError> { - self.shutdown()?; - Ok(Async::Ready(())) + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.shutdown(cx))?; + Poll::Ready(Ok(())) } } // TODO: remove (or improve) this impl From for Codec> where - T: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite + Unpin, { fn from(src: T) -> Self { Self::new(src) diff --git a/src/lib.rs b/src/lib.rs index 41a2d240b..a03ea9669 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -81,6 +81,7 @@ #![doc(html_root_url = "https://docs.rs/h2/0.1.25")] #![deny(missing_debug_implementations, missing_docs)] #![cfg_attr(test, deny(warnings))] +#![feature(async_await)] macro_rules! proto_err { (conn: $($msg:tt)+) => { @@ -91,9 +92,9 @@ macro_rules! proto_err { }; } -mod error; #[cfg_attr(feature = "unstable", allow(missing_docs))] mod codec; +mod error; mod hpack; mod proto; @@ -109,7 +110,48 @@ pub mod server; mod share; pub use crate::error::{Error, Reason}; -pub use crate::share::{SendStream, StreamId, RecvStream, ReleaseCapacity, PingPong, Ping, Pong}; +pub use crate::share::{Ping, PingPong, Pong, RecvStream, ReleaseCapacity, SendStream, StreamId}; #[cfg(feature = "unstable")] pub use codec::{Codec, RecvError, SendError, UserError}; + +use std::task::Poll; + +// TODO: Get rid of this trait once https://github.com/rust-lang/rust/pull/63512 +// is stablized. +trait PollExt { + /// Changes the success value of this `Poll` with the closure provided. + fn map_ok_(self, f: F) -> Poll>> + where + F: FnOnce(T) -> U; + /// Changes the error value of this `Poll` with the closure provided. + fn map_err_(self, f: F) -> Poll>> + where + F: FnOnce(E) -> U; +} + +impl PollExt for Poll>> { + fn map_ok_(self, f: F) -> Poll>> + where + F: FnOnce(T) -> U, + { + match self { + Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(Ok(f(t)))), + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } + + fn map_err_(self, f: F) -> Poll>> + where + F: FnOnce(E) -> U, + { + match self { + Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(Ok(t))), + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(f(e)))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} diff --git a/src/proto/connection.rs b/src/proto/connection.rs index 6183369df..3b1f7bb33 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -1,17 +1,18 @@ -use crate::{client, frame, proto, server}; use crate::codec::RecvError; use crate::frame::{Reason, StreamId}; +use crate::{client, frame, proto, server}; use crate::frame::DEFAULT_INITIAL_WINDOW_SIZE; use crate::proto::*; use bytes::{Bytes, IntoBuf}; -use futures::{Stream, try_ready}; -use tokio_io::{AsyncRead, AsyncWrite}; - -use std::marker::PhantomData; +use futures::{ready, Stream}; use std::io; +use std::marker::PhantomData; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; +use tokio_io::{AsyncRead, AsyncWrite}; /// An H2 connection #[derive(Debug)] @@ -70,16 +71,15 @@ enum State { impl Connection where - T: AsyncRead + AsyncWrite, + T: AsyncRead + AsyncWrite + Unpin, P: Peer, - B: IntoBuf, + B: IntoBuf + Unpin, + B::Buf: Unpin, { - pub fn new( - codec: Codec>, - config: Config, - ) -> Connection { + pub fn new(codec: Codec>, config: Config) -> Connection { let streams = Streams::new(streams::Config { - local_init_window_sz: config.settings + local_init_window_sz: config + .settings .initial_window_size() .unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE), initial_max_send_streams: config.initial_max_send_streams, @@ -88,7 +88,8 @@ where local_reset_duration: config.reset_stream_duration, local_reset_max: config.reset_stream_max, remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, - remote_max_initiated: config.settings + remote_max_initiated: config + .settings .max_concurrent_streams() .map(|max| max as usize), }); @@ -112,25 +113,24 @@ where /// /// Returns `RecvError` as this may raise errors that are caused by delayed /// processing of received frames. - fn poll_ready(&mut self) -> Poll<(), RecvError> { + fn poll_ready(&mut self, cx: &mut Context) -> Poll> { // The order of these calls don't really matter too much - try_ready!(self.ping_pong.send_pending_pong(&mut self.codec)); - try_ready!(self.ping_pong.send_pending_ping(&mut self.codec)); - try_ready!( - self.settings - .send_pending_ack(&mut self.codec, &mut self.streams) - ); - try_ready!(self.streams.send_pending_refusal(&mut self.codec)); - - Ok(().into()) + ready!(self.ping_pong.send_pending_pong(cx, &mut self.codec))?; + ready!(self.ping_pong.send_pending_ping(cx, &mut self.codec))?; + ready!(self + .settings + .send_pending_ack(cx, &mut self.codec, &mut self.streams))?; + ready!(self.streams.send_pending_refusal(cx, &mut self.codec))?; + + Poll::Ready(Ok(())) } /// Send any pending GOAWAY frames. /// /// This will return `Some(reason)` if the connection should be closed /// afterwards. If this is a graceful shutdown, this returns `None`. - fn poll_go_away(&mut self) -> Poll, io::Error> { - self.go_away.send_pending_go_away(&mut self.codec) + fn poll_go_away(&mut self, cx: &mut Context) -> Poll>> { + self.go_away.send_pending_go_away(cx, &mut self.codec) } fn go_away(&mut self, id: StreamId, e: Reason) { @@ -154,7 +154,7 @@ where self.streams.recv_err(&proto::Error::Proto(e)); } - fn take_error(&mut self, ours: Reason) -> Poll<(), proto::Error> { + fn take_error(&mut self, ours: Reason) -> Poll> { let reason = if let Some(theirs) = self.error.take() { match (ours, theirs) { // If either side reported an error, return that @@ -171,9 +171,9 @@ where }; if reason == Reason::NO_ERROR { - Ok(().into()) + Poll::Ready(Ok(())) } else { - Err(proto::Error::Proto(reason)) + Poll::Ready(Err(proto::Error::Proto(reason))) } } @@ -192,7 +192,7 @@ where } /// Advances the internal state of the connection. - pub fn poll(&mut self) -> Poll<(), proto::Error> { + pub fn poll(&mut self, cx: &mut Context) -> Poll> { use crate::codec::RecvError::*; loop { @@ -200,15 +200,15 @@ where match self.state { // When open, continue to poll a frame State::Open => { - match self.poll2() { + match self.poll2(cx) { // The connection has shutdown normally - Ok(Async::Ready(())) => self.state = State::Closing(Reason::NO_ERROR), + Poll::Ready(Ok(())) => self.state = State::Closing(Reason::NO_ERROR), // The connection is not ready to make progress - Ok(Async::NotReady) => { + Poll::Pending => { // Ensure all window updates have been sent. // // This will also handle flushing `self.codec` - try_ready!(self.streams.poll_complete(&mut self.codec)); + ready!(self.streams.poll_complete(cx, &mut self.codec))?; if self.error.is_some() || self.go_away.should_close_on_idle() { if !self.streams.has_streams() { @@ -217,12 +217,12 @@ where } } - return Ok(Async::NotReady); - }, + return Poll::Pending; + } // Attempting to read a frame resulted in a connection level // error. This is handled by setting a GOAWAY frame followed by // terminating the connection. - Err(Connection(e)) => { + Poll::Ready(Err(Connection(e))) => { log::debug!("Connection::poll; connection error={:?}", e); // We may have already sent a GOAWAY for this error, @@ -238,22 +238,19 @@ where // Reset all active streams self.streams.recv_err(&e.into()); self.go_away_now(e); - }, + } // Attempting to read a frame resulted in a stream level error. // This is handled by resetting the frame then trying to read // another frame. - Err(Stream { - id, - reason, - }) => { + Poll::Ready(Err(Stream { id, reason })) => { log::trace!("stream error; id={:?}; reason={:?}", id, reason); self.streams.send_reset(id, reason); - }, + } // Attempting to read a frame resulted in an I/O error. All // active streams must be reset. // // TODO: Are I/O errors recoverable? - Err(Io(e)) => { + Poll::Ready(Err(Io(e))) => { log::debug!("Connection::poll; IO error={:?}", e); let e = e.into(); @@ -261,24 +258,24 @@ where self.streams.recv_err(&e); // Return the error - return Err(e); - }, + return Poll::Ready(Err(e)); + } } } State::Closing(reason) => { log::trace!("connection closing after flush"); // Flush/shutdown the codec - try_ready!(self.codec.shutdown()); + ready!(self.codec.shutdown(cx))?; // Transition the state to error self.state = State::Closed(reason); - }, + } State::Closed(reason) => return self.take_error(reason), } } } - fn poll2(&mut self) -> Poll<(), RecvError> { + fn poll2(&mut self, cx: &mut Context) -> Poll> { use crate::frame::Frame::*; // This happens outside of the loop to prevent needing to do a clock @@ -292,42 +289,49 @@ where // The order here matters: // - poll_go_away may buffer a graceful shutdown GOAWAY frame // - If it has, we've also added a PING to be sent in poll_ready - if let Some(reason) = try_ready!(self.poll_go_away()) { - if self.go_away.should_close_now() { - if self.go_away.is_user_initiated() { - // A user initiated abrupt shutdown shouldn't return - // the same error back to the user. - return Ok(Async::Ready(())); - } else { - return Err(RecvError::Connection(reason)); + match ready!(self.poll_go_away(cx)?) { + Some(reason) => { + if self.go_away.should_close_now() { + if self.go_away.is_user_initiated() { + // A user initiated abrupt shutdown shouldn't return + // the same error back to the user. + return Poll::Ready(Ok(())); + } else { + return Poll::Ready(Err(RecvError::Connection(reason))); + } } + // Only NO_ERROR should be waiting for idle + debug_assert_eq!( + reason, + Reason::NO_ERROR, + "graceful GOAWAY should be NO_ERROR" + ); } - // Only NO_ERROR should be waiting for idle - debug_assert_eq!(reason, Reason::NO_ERROR, "graceful GOAWAY should be NO_ERROR"); + None => (), } - try_ready!(self.poll_ready()); + ready!(self.poll_ready(cx))?; - match try_ready!(self.codec.poll()) { + match ready!(Pin::new(&mut self.codec).poll_next(cx)?) { Some(Headers(frame)) => { log::trace!("recv HEADERS; frame={:?}", frame); self.streams.recv_headers(frame)?; - }, + } Some(Data(frame)) => { log::trace!("recv DATA; frame={:?}", frame); self.streams.recv_data(frame)?; - }, + } Some(Reset(frame)) => { log::trace!("recv RST_STREAM; frame={:?}", frame); self.streams.recv_reset(frame)?; - }, + } Some(PushPromise(frame)) => { log::trace!("recv PUSH_PROMISE; frame={:?}", frame); self.streams.recv_push_promise(frame)?; - }, + } Some(Settings(frame)) => { log::trace!("recv SETTINGS; frame={:?}", frame); self.settings.recv_settings(frame); - }, + } Some(GoAway(frame)) => { log::trace!("recv GOAWAY; frame={:?}", frame); // This should prevent starting new streams, @@ -336,7 +340,7 @@ where // transition to GoAway. self.streams.recv_go_away(&frame)?; self.error = Some(frame.reason()); - }, + } Some(Ping(frame)) => { log::trace!("recv PING; frame={:?}", frame); let status = self.ping_pong.recv_ping(frame); @@ -349,21 +353,20 @@ where let last_processed_id = self.streams.last_processed_id(); self.go_away(last_processed_id, Reason::NO_ERROR); } - }, + } Some(WindowUpdate(frame)) => { log::trace!("recv WINDOW_UPDATE; frame={:?}", frame); self.streams.recv_window_update(frame)?; - }, + } Some(Priority(frame)) => { log::trace!("recv PRIORITY; frame={:?}", frame); // TODO: handle - }, + } None => { log::trace!("codec closed"); - self.streams.recv_eof(false) - .ok().expect("mutex poisoned"); - return Ok(Async::Ready(())); - }, + self.streams.recv_eof(false).ok().expect("mutex poisoned"); + return Poll::Ready(Ok(())); + } } } } @@ -385,8 +388,9 @@ where impl Connection where - T: AsyncRead + AsyncWrite, - B: IntoBuf, + T: AsyncRead + AsyncWrite + Unpin, + B: IntoBuf + Unpin, + B::Buf: Unpin, { pub fn next_incoming(&mut self) -> Option> { self.streams.next_incoming() diff --git a/src/proto/go_away.rs b/src/proto/go_away.rs index 42ff0f0a9..1ac2f2e0f 100644 --- a/src/proto/go_away.rs +++ b/src/proto/go_away.rs @@ -2,8 +2,8 @@ use crate::codec::Codec; use crate::frame::{self, Reason, StreamId}; use bytes::Buf; -use futures::{Async, Poll}; use std::io; +use std::task::{Context, Poll}; use tokio_io::AsyncWrite; /// Manages our sending of GOAWAY frames. @@ -59,7 +59,7 @@ impl GoAway { assert!( f.last_stream_id() <= going_away.last_processed_id, "GOAWAY stream IDs shouldn't be higher; \ - last_processed_id = {:?}, f.last_stream_id() = {:?}", + last_processed_id = {:?}, f.last_stream_id() = {:?}", going_away.last_processed_id, f.last_stream_id(), ); @@ -76,8 +76,8 @@ impl GoAway { self.close_now = true; if let Some(ref going_away) = self.going_away { // Prevent sending the same GOAWAY twice. - if going_away.last_processed_id == f.last_stream_id() - && going_away.reason == f.reason() { + if going_away.last_processed_id == f.last_stream_id() && going_away.reason == f.reason() + { return; } } @@ -100,9 +100,7 @@ impl GoAway { /// Return the last Reason we've sent. pub fn going_away_reason(&self) -> Option { - self.going_away - .as_ref() - .map(|g| g.reason) + self.going_away.as_ref().map(|g| g.reason) } /// Returns if the connection should close now, or wait until idle. @@ -112,36 +110,43 @@ impl GoAway { /// Returns if the connection should be closed when idle. pub fn should_close_on_idle(&self) -> bool { - !self.close_now && self.going_away - .as_ref() - .map(|g| g.last_processed_id != StreamId::MAX) - .unwrap_or(false) + !self.close_now + && self + .going_away + .as_ref() + .map(|g| g.last_processed_id != StreamId::MAX) + .unwrap_or(false) } /// Try to write a pending GOAWAY frame to the buffer. /// /// If a frame is written, the `Reason` of the GOAWAY is returned. - pub fn send_pending_go_away(&mut self, dst: &mut Codec) -> Poll, io::Error> + pub fn send_pending_go_away( + &mut self, + cx: &mut Context, + dst: &mut Codec, + ) -> Poll>> where - T: AsyncWrite, - B: Buf, + T: AsyncWrite + Unpin, + B: Buf + Unpin, { if let Some(frame) = self.pending.take() { - if !dst.poll_ready()?.is_ready() { + if !dst.poll_ready(cx)?.is_ready() { self.pending = Some(frame); - return Ok(Async::NotReady); + return Poll::Pending; } let reason = frame.reason(); - dst.buffer(frame.into()) - .ok() - .expect("invalid GOAWAY frame"); + dst.buffer(frame.into()).ok().expect("invalid GOAWAY frame"); - return Ok(Async::Ready(Some(reason))); + return Poll::Ready(Some(Ok(reason))); } else if self.should_close_now() { - return Ok(Async::Ready(self.going_away_reason())); + return match self.going_away_reason() { + Some(reason) => Poll::Ready(Some(Ok(reason))), + None => Poll::Ready(None), + }; } - Ok(Async::Ready(None)) + Poll::Ready(None) } } diff --git a/src/proto/mod.rs b/src/proto/mod.rs index ae43bdadb..4b6e0909a 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -8,10 +8,10 @@ mod streams; pub(crate) use self::connection::{Config, Connection}; pub(crate) use self::error::Error; -pub(crate) use self::peer::{Peer, Dyn as DynPeer}; +pub(crate) use self::peer::{Dyn as DynPeer, Peer}; pub(crate) use self::ping_pong::UserPings; -pub(crate) use self::streams::{StreamRef, OpaqueStreamRef, Streams}; -pub(crate) use self::streams::{PollReset, Prioritized, Open}; +pub(crate) use self::streams::{OpaqueStreamRef, StreamRef, Streams}; +pub(crate) use self::streams::{Open, PollReset, Prioritized}; use crate::codec::Codec; @@ -21,9 +21,6 @@ use self::settings::Settings; use crate::frame::{self, Frame}; -use futures::{task, Async, Poll}; -use futures::task::Task; - use bytes::Buf; use tokio_io::AsyncWrite; diff --git a/src/proto/ping_pong.rs b/src/proto/ping_pong.rs index bc24c8230..0dbbec2d5 100644 --- a/src/proto/ping_pong.rs +++ b/src/proto/ping_pong.rs @@ -3,11 +3,11 @@ use crate::frame::Ping; use crate::proto::{self, PingPayload}; use bytes::Buf; -use futures::{Async, Poll}; -use futures::task::AtomicTask; +use futures::task::AtomicWaker; use std::io; -use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; use tokio_io::AsyncWrite; /// Acknowledges ping requests from the remote. @@ -28,9 +28,9 @@ struct UserPingsRx(Arc); struct UserPingsInner { state: AtomicUsize, /// Task to wake up the main `Connection`. - ping_task: AtomicTask, + ping_task: AtomicWaker, /// Task to wake up `share::PingPong::poll_pong`. - pong_task: AtomicTask, + pong_task: AtomicWaker, } #[derive(Debug)] @@ -77,8 +77,8 @@ impl PingPong { let user_pings = Arc::new(UserPingsInner { state: AtomicUsize::new(USER_STATE_EMPTY), - ping_task: AtomicTask::new(), - pong_task: AtomicTask::new(), + ping_task: AtomicWaker::new(), + pong_task: AtomicWaker::new(), }); self.user_pings = Some(UserPingsRx(user_pings.clone())); Some(UserPings(user_pings)) @@ -135,34 +135,42 @@ impl PingPong { } /// Send any pending pongs. - pub(crate) fn send_pending_pong(&mut self, dst: &mut Codec) -> Poll<(), io::Error> + pub(crate) fn send_pending_pong( + &mut self, + cx: &mut Context, + dst: &mut Codec, + ) -> Poll> where - T: AsyncWrite, - B: Buf, + T: AsyncWrite + Unpin, + B: Buf + Unpin, { if let Some(pong) = self.pending_pong.take() { - if !dst.poll_ready()?.is_ready() { + if !dst.poll_ready(cx)?.is_ready() { self.pending_pong = Some(pong); - return Ok(Async::NotReady); + return Poll::Pending; } dst.buffer(Ping::pong(pong).into()) .expect("invalid pong frame"); } - Ok(Async::Ready(())) + Poll::Ready(Ok(())) } /// Send any pending pings. - pub(crate) fn send_pending_ping(&mut self, dst: &mut Codec) -> Poll<(), io::Error> + pub(crate) fn send_pending_ping( + &mut self, + cx: &mut Context, + dst: &mut Codec, + ) -> Poll> where - T: AsyncWrite, - B: Buf, + T: AsyncWrite + Unpin, + B: Buf + Unpin, { if let Some(ref mut ping) = self.pending_ping { if !ping.sent { - if !dst.poll_ready()?.is_ready() { - return Ok(Async::NotReady); + if !dst.poll_ready(cx)?.is_ready() { + return Poll::Pending; } dst.buffer(Ping::new(ping.payload).into()) @@ -171,19 +179,22 @@ impl PingPong { } } else if let Some(ref users) = self.user_pings { if users.0.state.load(Ordering::Acquire) == USER_STATE_PENDING_PING { - if !dst.poll_ready()?.is_ready() { - return Ok(Async::NotReady); + if !dst.poll_ready(cx)?.is_ready() { + return Poll::Pending; } dst.buffer(Ping::new(Ping::USER).into()) .expect("invalid ping frame"); - users.0.state.store(USER_STATE_PENDING_PONG, Ordering::Release); + users + .0 + .state + .store(USER_STATE_PENDING_PONG, Ordering::Release); } else { - users.0.ping_task.register(); + users.0.ping_task.register(cx.waker()); } } - Ok(Async::Ready(())) + Poll::Ready(Ok(())) } } @@ -201,19 +212,17 @@ impl ReceivedPing { impl UserPings { pub(crate) fn send_ping(&self) -> Result<(), Option> { let prev = self.0.state.compare_and_swap( - USER_STATE_EMPTY, // current + USER_STATE_EMPTY, // current USER_STATE_PENDING_PING, // new Ordering::AcqRel, ); match prev { USER_STATE_EMPTY => { - self.0.ping_task.notify(); + self.0.ping_task.wake(); Ok(()) - }, - USER_STATE_CLOSED => { - Err(Some(broken_pipe().into())) } + USER_STATE_CLOSED => Err(Some(broken_pipe().into())), _ => { // Was already pending, user error! Err(None) @@ -221,20 +230,20 @@ impl UserPings { } } - pub(crate) fn poll_pong(&self) -> Poll<(), proto::Error> { + pub(crate) fn poll_pong(&self, cx: &mut Context) -> Poll> { // Must register before checking state, in case state were to change // before we could register, and then the ping would just be lost. - self.0.pong_task.register(); + self.0.pong_task.register(cx.waker()); let prev = self.0.state.compare_and_swap( USER_STATE_RECEIVED_PONG, // current - USER_STATE_EMPTY, // new + USER_STATE_EMPTY, // new Ordering::AcqRel, ); match prev { - USER_STATE_RECEIVED_PONG => Ok(Async::Ready(())), - USER_STATE_CLOSED => Err(broken_pipe().into()), - _ => Ok(Async::NotReady), + USER_STATE_RECEIVED_PONG => Poll::Ready(Ok(())), + USER_STATE_CLOSED => Poll::Ready(Err(broken_pipe().into())), + _ => Poll::Pending, } } } @@ -244,13 +253,13 @@ impl UserPings { impl UserPingsRx { fn receive_pong(&self) -> bool { let prev = self.0.state.compare_and_swap( - USER_STATE_PENDING_PONG, // current + USER_STATE_PENDING_PONG, // current USER_STATE_RECEIVED_PONG, // new Ordering::AcqRel, ); if prev == USER_STATE_PENDING_PONG { - self.0.pong_task.notify(); + self.0.pong_task.wake(); true } else { false @@ -261,7 +270,7 @@ impl UserPingsRx { impl Drop for UserPingsRx { fn drop(&mut self) { self.0.state.store(USER_STATE_CLOSED, Ordering::Release); - self.0.pong_task.notify(); + self.0.pong_task.wake(); } } diff --git a/src/proto/settings.rs b/src/proto/settings.rs index 4007993ff..f35aefa0e 100644 --- a/src/proto/settings.rs +++ b/src/proto/settings.rs @@ -1,6 +1,7 @@ use crate::codec::RecvError; use crate::frame; use crate::proto::*; +use std::task::{Poll, Context}; #[derive(Debug)] pub(crate) struct Settings { @@ -29,21 +30,22 @@ impl Settings { pub fn send_pending_ack( &mut self, + cx: &mut Context, dst: &mut Codec, streams: &mut Streams, - ) -> Poll<(), RecvError> + ) -> Poll> where - T: AsyncWrite, - B: Buf, - C: Buf, + T: AsyncWrite + Unpin, + B: Buf + Unpin, + C: Buf + Unpin, P: Peer, { log::trace!("send_pending_ack; pending={:?}", self.pending); - if let Some(ref settings) = self.pending { - if !dst.poll_ready()?.is_ready() { + if let Some(settings) = &self.pending { + if !dst.poll_ready(cx)?.is_ready() { log::trace!("failed to send ACK"); - return Ok(Async::NotReady); + return Poll::Pending; } // Create an ACK settings frame @@ -65,6 +67,6 @@ impl Settings { self.pending = None; - Ok(().into()) + Poll::Ready(Ok(())) } } diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs index fee9e5798..efa105053 100644 --- a/src/proto/streams/prioritize.rs +++ b/src/proto/streams/prioritize.rs @@ -7,10 +7,10 @@ use crate::codec::UserError; use crate::codec::UserError::*; use bytes::buf::Take; -use futures::try_ready; - +use futures::ready; use std::{cmp, fmt, mem}; use std::io; +use std::task::{Context, Poll, Waker}; /// # Warning /// @@ -104,14 +104,14 @@ impl Prioritize { frame: Frame, buffer: &mut Buffer>, stream: &mut store::Ptr, - task: &mut Option, + task: &mut Option, ) { // Queue the frame in the buffer stream.pending_send.push_back(buffer, frame); self.schedule_send(stream, task); } - pub fn schedule_send(&mut self, stream: &mut store::Ptr, task: &mut Option) { + pub fn schedule_send(&mut self, stream: &mut store::Ptr, task: &mut Option) { // If the stream is waiting to be opened, nothing more to do. if !stream.is_pending_open { log::trace!("schedule_send; {:?}", stream.id); @@ -120,7 +120,7 @@ impl Prioritize { // Notify the connection. if let Some(task) = task.take() { - task.notify(); + task.wake(); } } } @@ -136,7 +136,7 @@ impl Prioritize { buffer: &mut Buffer>, stream: &mut store::Ptr, counts: &mut Counts, - task: &mut Option, + task: &mut Option, ) -> Result<(), UserError> where B: Buf, @@ -483,17 +483,18 @@ impl Prioritize { pub fn poll_complete( &mut self, + cx: &mut Context, buffer: &mut Buffer>, store: &mut Store, counts: &mut Counts, dst: &mut Codec>, - ) -> Poll<(), io::Error> + ) -> Poll> where - T: AsyncWrite, - B: Buf, + T: AsyncWrite + Unpin, + B: Buf + Unpin, { // Ensure codec is ready - try_ready!(dst.poll_ready()); + ready!(dst.poll_ready(cx))?; // Reclaim any frame that has previously been written self.reclaim_frame(buffer, store, dst); @@ -517,18 +518,18 @@ impl Prioritize { dst.buffer(frame).ok().expect("invalid frame"); // Ensure the codec is ready to try the loop again. - try_ready!(dst.poll_ready()); + ready!(dst.poll_ready(cx))?; // Because, always try to reclaim... self.reclaim_frame(buffer, store, dst); }, None => { // Try to flush the codec. - try_ready!(dst.flush()); + ready!(dst.flush(cx))?; // This might release a data frame... if !self.reclaim_frame(buffer, store, dst) { - return Ok(().into()); + return Poll::Ready(Ok(())) } // No need to poll ready as poll_complete() does this for diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index eb03a680a..9c78f7c84 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -1,13 +1,15 @@ +use std::task::Context; use super::*; use crate::{frame, proto}; use crate::codec::{RecvError, UserError}; use crate::frame::{Reason, DEFAULT_INITIAL_WINDOW_SIZE}; use http::{HeaderMap, Response, Request, Method}; -use futures::try_ready; +use futures::ready; use std::io; use std::time::{Duration, Instant}; +use std::task::{Poll, Waker}; #[derive(Debug)] pub(super) struct Recv { @@ -257,15 +259,17 @@ impl Recv { /// Called by the client to get pushed response pub fn poll_pushed( - &mut self, stream: &mut store::Ptr - ) -> Poll, store::Key)>, proto::Error> { + &mut self, + cx: &Context, + stream: &mut store::Ptr + ) -> Poll, store::Key), proto::Error>>> { use super::peer::PollMessage::*; let mut ppp = stream.pending_push_promises.take(); let pushed = ppp.pop(stream.store_mut()).map( |mut pushed| match pushed.pending_recv.pop_front(&mut self.buffer) { Some(Event::Headers(Server(headers))) => - Async::Ready(Some((headers, pushed.key()))), + (headers, pushed.key()), // When frames are pushed into the queue, it is verified that // the first frame is a HEADERS frame. _ => panic!("Headers not set on pushed stream") @@ -273,15 +277,15 @@ impl Recv { ); stream.pending_push_promises = ppp; if let Some(p) = pushed { - Ok(p) + Poll::Ready(Some(Ok(p))) } else { let is_open = stream.state.ensure_recv_open()?; if is_open { - stream.recv_task = Some(task::current()); - Ok(Async::NotReady) + stream.recv_task = Some(cx.waker().clone()); + Poll::Pending } else { - Ok(Async::Ready(None)) + Poll::Ready(None) } } } @@ -289,20 +293,21 @@ impl Recv { /// Called by the client to get the response pub fn poll_response( &mut self, + cx: &Context, stream: &mut store::Ptr, - ) -> Poll, proto::Error> { + ) -> Poll, proto::Error>> { use super::peer::PollMessage::*; // If the buffer is not empty, then the first frame must be a HEADERS // frame or the user violated the contract. match stream.pending_recv.pop_front(&mut self.buffer) { - Some(Event::Headers(Client(response))) => Ok(response.into()), + Some(Event::Headers(Client(response))) => Poll::Ready(Ok(response.into())), Some(_) => panic!("poll_response called after response returned"), None => { stream.state.ensure_recv_open()?; - stream.recv_task = Some(task::current()); - Ok(Async::NotReady) + stream.recv_task = Some(cx.waker().clone()); + Poll::Pending }, } } @@ -339,7 +344,7 @@ impl Recv { pub fn release_connection_capacity( &mut self, capacity: WindowSize, - task: &mut Option, + task: &mut Option, ) { log::trace!( "release_connection_capacity; size={}, connection in_flight_data={}", @@ -355,7 +360,7 @@ impl Recv { if self.flow.unclaimed_capacity().is_some() { if let Some(task) = task.take() { - task.notify(); + task.wake(); } } } @@ -365,7 +370,7 @@ impl Recv { &mut self, capacity: WindowSize, stream: &mut store::Ptr, - task: &mut Option, + task: &mut Option, ) -> Result<(), UserError> { log::trace!("release_capacity; size={}", capacity); @@ -387,7 +392,7 @@ impl Recv { self.pending_window_updates.push(stream); if let Some(task) = task.take() { - task.notify(); + task.wake(); } } @@ -398,7 +403,7 @@ impl Recv { pub fn release_closed_capacity( &mut self, stream: &mut store::Ptr, - task: &mut Option, + task: &mut Option, ) { debug_assert_eq!(stream.ref_count, 0); @@ -433,7 +438,7 @@ impl Recv { /// /// The `task` is an optional parked task for the `Connection` that might /// be blocked on needing more window capacity. - pub fn set_target_connection_window(&mut self, target: WindowSize, task: &mut Option) { + pub fn set_target_connection_window(&mut self, target: WindowSize, task: &mut Option) { log::trace!( "set_target_connection_window; target={}; available={}, reserved={}", target, @@ -458,7 +463,7 @@ impl Recv { // a connection WINDOW_UPDATE. if self.flow.unclaimed_capacity().is_some() { if let Some(task) = task.take() { - task.notify(); + task.wake(); } } } @@ -824,14 +829,15 @@ impl Recv { /// Send any pending refusals. pub fn send_pending_refusal( &mut self, + cx: &mut Context, dst: &mut Codec>, - ) -> Poll<(), io::Error> + ) -> Poll> where - T: AsyncWrite, - B: Buf, + T: AsyncWrite + Unpin, + B: Buf + Unpin, { if let Some(stream_id) = self.refused { - try_ready!(dst.poll_ready()); + ready!(dst.poll_ready(cx))?; // Create the RST_STREAM frame let frame = frame::Reset::new(stream_id, Reason::REFUSED_STREAM); @@ -844,7 +850,7 @@ impl Recv { self.refused = None; - Ok(Async::Ready(())) + Poll::Ready(Ok(())) } pub fn clear_expired_reset_streams(&mut self, store: &mut Store, counts: &mut Counts) { @@ -894,37 +900,39 @@ impl Recv { pub fn poll_complete( &mut self, + cx: &mut Context, store: &mut Store, counts: &mut Counts, dst: &mut Codec>, - ) -> Poll<(), io::Error> + ) -> Poll> where - T: AsyncWrite, - B: Buf, + T: AsyncWrite + Unpin, + B: Buf + Unpin, { // Send any pending connection level window updates - try_ready!(self.send_connection_window_update(dst)); + ready!(self.send_connection_window_update(cx, dst))?; // Send any pending stream level window updates - try_ready!(self.send_stream_window_updates(store, counts, dst)); + ready!(self.send_stream_window_updates(cx, store, counts, dst))?; - Ok(().into()) + Poll::Ready(Ok(())) } /// Send connection level window update fn send_connection_window_update( &mut self, + cx: &mut Context, dst: &mut Codec>, - ) -> Poll<(), io::Error> + ) -> Poll> where - T: AsyncWrite, - B: Buf, + T: AsyncWrite + Unpin, + B: Buf + Unpin, { if let Some(incr) = self.flow.unclaimed_capacity() { let frame = frame::WindowUpdate::new(StreamId::zero(), incr); // Ensure the codec has capacity - try_ready!(dst.poll_ready()); + ready!(dst.poll_ready(cx))?; // Buffer the WINDOW_UPDATE frame dst.buffer(frame.into()) @@ -938,28 +946,29 @@ impl Recv { .expect("unexpected flow control state"); } - Ok(().into()) + Poll::Ready(Ok(())) } /// Send stream level window update pub fn send_stream_window_updates( &mut self, + cx: &mut Context, store: &mut Store, counts: &mut Counts, dst: &mut Codec>, - ) -> Poll<(), io::Error> + ) -> Poll> where - T: AsyncWrite, - B: Buf, + T: AsyncWrite + Unpin, + B: Buf + Unpin, { loop { // Ensure the codec has capacity - try_ready!(dst.poll_ready()); + ready!(dst.poll_ready(cx))?; // Get the next stream let stream = match self.pending_window_updates.pop(store) { Some(stream) => stream, - None => return Ok(().into()), + None => return Poll::Ready(Ok(())), }; counts.transition(stream, |_, stream| { @@ -1001,10 +1010,10 @@ impl Recv { self.pending_accept.pop(store).map(|ptr| ptr.key()) } - pub fn poll_data(&mut self, stream: &mut Stream) -> Poll, proto::Error> { + pub fn poll_data(&mut self, cx: &Context, stream: &mut Stream) -> Poll>> { // TODO: Return error when the stream is reset match stream.pending_recv.pop_front(&mut self.buffer) { - Some(Event::Data(payload)) => Ok(Some(payload).into()), + Some(Event::Data(payload)) => Poll::Ready(Some(Ok(payload))), Some(event) => { // Frame is trailer stream.pending_recv.push_front(&mut self.buffer, event); @@ -1020,36 +1029,37 @@ impl Recv { stream.notify_recv(); // No more data frames - Ok(None.into()) + Poll::Ready(None) }, - None => self.schedule_recv(stream), + None => self.schedule_recv(cx, stream), } } pub fn poll_trailers( &mut self, + cx: &Context, stream: &mut Stream, - ) -> Poll, proto::Error> { + ) -> Poll>> { match stream.pending_recv.pop_front(&mut self.buffer) { - Some(Event::Trailers(trailers)) => Ok(Some(trailers).into()), + Some(Event::Trailers(trailers)) => Poll::Ready(Some(Ok(trailers))), Some(event) => { // Frame is not trailers.. not ready to poll trailers yet. stream.pending_recv.push_front(&mut self.buffer, event); - Ok(Async::NotReady) + Poll::Pending }, - None => self.schedule_recv(stream), + None => self.schedule_recv(cx, stream), } } - fn schedule_recv(&mut self, stream: &mut Stream) -> Poll, proto::Error> { + fn schedule_recv(&mut self, cx: &Context, stream: &mut Stream) -> Poll>> { if stream.state.ensure_recv_open()? { // Request to get notified once more frames arrive - stream.recv_task = Some(task::current()); - Ok(Async::NotReady) + stream.recv_task = Some(cx.waker().clone()); + Poll::Pending } else { // No more frames will be received - Ok(None.into()) + Poll::Ready(None) } } } diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index 9bbd0438b..4a723ce5b 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -1,14 +1,13 @@ -use crate::codec::{RecvError, UserError}; -use crate::frame::{self, Reason}; use super::{ - store, Buffer, Codec, Config, Counts, Frame, Prioritize, - Prioritized, Store, Stream, StreamId, StreamIdOverflow, WindowSize, + store, Buffer, Codec, Config, Counts, Frame, Prioritize, Prioritized, Store, Stream, StreamId, + StreamIdOverflow, WindowSize, }; +use crate::codec::{RecvError, UserError}; +use crate::frame::{self, Reason}; use bytes::Buf; use http; -use futures::{Async, Poll}; -use futures::task::Task; +use std::task::{Context, Poll, Waker}; use tokio_io::AsyncWrite; use std::io; @@ -60,7 +59,7 @@ impl Send { buffer: &mut Buffer>, stream: &mut store::Ptr, counts: &mut Counts, - task: &mut Option, + task: &mut Option, ) -> Result<(), UserError> { log::trace!( "send_headers; frame={:?}; init_window={:?}", @@ -81,7 +80,6 @@ impl Send { if te != "trailers" { log::debug!("illegal connection-specific headers found"); return Err(UserError::MalformedHeaders); - } } @@ -103,7 +101,8 @@ impl Send { } // Queue the frame for sending - self.prioritize.queue_frame(frame.into(), buffer, stream, task); + self.prioritize + .queue_frame(frame.into(), buffer, stream, task); Ok(()) } @@ -115,7 +114,7 @@ impl Send { buffer: &mut Buffer>, stream: &mut store::Ptr, counts: &mut Counts, - task: &mut Option, + task: &mut Option, ) { let is_reset = stream.state.is_reset(); let is_closed = stream.state.is_closed(); @@ -125,7 +124,7 @@ impl Send { "send_reset(..., reason={:?}, stream={:?}, ..., \ is_reset={:?}; is_closed={:?}; pending_send.is_empty={:?}; \ state={:?} \ - ", + ", reason, stream.id, is_reset, @@ -151,7 +150,7 @@ impl Send { if is_closed && is_empty { log::trace!( " -> not sending explicit RST_STREAM ({:?} was closed \ - and send queue was flushed)", + and send queue was flushed)", stream.id ); return; @@ -166,7 +165,8 @@ impl Send { let frame = frame::Reset::new(stream.id, reason); log::trace!("send_reset -- queueing; frame={:?}", frame); - self.prioritize.queue_frame(frame.into(), buffer, stream, task); + self.prioritize + .queue_frame(frame.into(), buffer, stream, task); self.prioritize.reclaim_all_capacity(stream, counts); } @@ -175,7 +175,7 @@ impl Send { stream: &mut store::Ptr, reason: Reason, counts: &mut Counts, - task: &mut Option, + task: &mut Option, ) { if stream.state.is_closed() { // Stream is already closed, nothing more to do @@ -194,11 +194,13 @@ impl Send { buffer: &mut Buffer>, stream: &mut store::Ptr, counts: &mut Counts, - task: &mut Option, + task: &mut Option, ) -> Result<(), UserError> - where B: Buf, + where + B: Buf, { - self.prioritize.send_data(frame, buffer, stream, counts, task) + self.prioritize + .send_data(frame, buffer, stream, counts, task) } pub fn send_trailers( @@ -207,7 +209,7 @@ impl Send { buffer: &mut Buffer>, stream: &mut store::Ptr, counts: &mut Counts, - task: &mut Option, + task: &mut Option, ) -> Result<(), UserError> { // TODO: Should this logic be moved into state.rs? if !stream.state.is_send_streaming() { @@ -221,7 +223,8 @@ impl Send { stream.state.send_close(); log::trace!("send_trailers -- queuing; frame={:?}", frame); - self.prioritize.queue_frame(frame.into(), buffer, stream, task); + self.prioritize + .queue_frame(frame.into(), buffer, stream, task); // Release any excess capacity self.prioritize.reserve_capacity(0, stream, counts); @@ -231,15 +234,18 @@ impl Send { pub fn poll_complete( &mut self, + cx: &mut Context, buffer: &mut Buffer>, store: &mut Store, counts: &mut Counts, dst: &mut Codec>, - ) -> Poll<(), io::Error> - where T: AsyncWrite, - B: Buf, + ) -> Poll> + where + T: AsyncWrite + Unpin, + B: Buf + Unpin, { - self.prioritize.poll_complete(buffer, store, counts, dst) + self.prioritize + .poll_complete(cx, buffer, store, counts, dst) } /// Request capacity to send data @@ -247,27 +253,28 @@ impl Send { &mut self, capacity: WindowSize, stream: &mut store::Ptr, - counts: &mut Counts) - { + counts: &mut Counts, + ) { self.prioritize.reserve_capacity(capacity, stream, counts) } pub fn poll_capacity( &mut self, + cx: &Context, stream: &mut store::Ptr, - ) -> Poll, UserError> { + ) -> Poll>> { if !stream.state.is_send_streaming() { - return Ok(Async::Ready(None)); + return Poll::Ready(None); } if !stream.send_capacity_inc { - stream.wait_send(); - return Ok(Async::NotReady); + stream.wait_send(cx); + return Poll::Pending; } stream.send_capacity_inc = false; - Ok(Async::Ready(Some(self.capacity(stream)))) + Poll::Ready(Some(Ok(self.capacity(stream)))) } /// Current available stream send capacity @@ -284,15 +291,16 @@ impl Send { pub fn poll_reset( &self, + cx: &Context, stream: &mut Stream, mode: PollReset, - ) -> Poll { + ) -> Poll> { match stream.state.ensure_reason(mode)? { - Some(reason) => Ok(reason.into()), + Some(reason) => Poll::Ready(Ok(reason)), None => { - stream.wait_send(); - Ok(Async::NotReady) - }, + stream.wait_send(cx); + Poll::Pending + } } } @@ -312,14 +320,18 @@ impl Send { buffer: &mut Buffer>, stream: &mut store::Ptr, counts: &mut Counts, - task: &mut Option, + task: &mut Option, ) -> Result<(), Reason> { if let Err(e) = self.prioritize.recv_stream_window_update(sz, stream) { log::debug!("recv_stream_window_update !!; err={:?}", e); self.send_reset( Reason::FLOW_CONTROL_ERROR.into(), - buffer, stream, counts, task); + buffer, + stream, + counts, + task, + ); return Err(e); } @@ -344,7 +356,7 @@ impl Send { buffer: &mut Buffer>, store: &mut Store, counts: &mut Counts, - task: &mut Option, + task: &mut Option, ) -> Result<(), RecvError> { // Applies an update to the remote endpoint's initial window size. // @@ -444,16 +456,14 @@ impl Send { } pub fn ensure_next_stream_id(&self) -> Result { - self.next_stream_id.map_err(|_| UserError::OverflowedStreamId) + self.next_stream_id + .map_err(|_| UserError::OverflowedStreamId) } pub fn may_have_created_stream(&self, id: StreamId) -> bool { if let Ok(next_id) = self.next_stream_id { // Peer::is_local_init should have been called beforehand - debug_assert_eq!( - id.is_server_initiated(), - next_id.is_server_initiated(), - ); + debug_assert_eq!(id.is_server_initiated(), next_id.is_server_initiated(),); id < next_id } else { true diff --git a/src/proto/streams/stream.rs b/src/proto/streams/stream.rs index c677a4a4e..d3caf5ca0 100644 --- a/src/proto/streams/stream.rs +++ b/src/proto/streams/stream.rs @@ -2,6 +2,7 @@ use super::*; use std::time::Instant; use std::usize; +use std::task::{Context, Waker}; /// Tracks Stream related state /// @@ -47,7 +48,7 @@ pub(super) struct Stream { pub buffered_send_data: WindowSize, /// Task tracking additional send capacity (i.e. window updates). - send_task: Option, + send_task: Option, /// Frames pending for this stream being sent to the socket pub pending_send: buffer::Deque, @@ -96,7 +97,7 @@ pub(super) struct Stream { pub pending_recv: buffer::Deque, /// Task tracking receiving frames - pub recv_task: Option, + pub recv_task: Option, /// The stream's pending push promises pub pending_push_promises: store::Queue, @@ -280,17 +281,17 @@ impl Stream { pub fn notify_send(&mut self) { if let Some(task) = self.send_task.take() { - task.notify(); + task.wake(); } } - pub fn wait_send(&mut self) { - self.send_task = Some(task::current()); + pub fn wait_send(&mut self, cx: &Context) { + self.send_task = Some(cx.waker().clone()); } pub fn notify_recv(&mut self) { if let Some(task) = self.recv_task.take() { - task.notify(); + task.wake(); } } } diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index 7ba818b2e..59f74aa19 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -1,18 +1,20 @@ -use crate::{client, proto, server}; -use crate::codec::{Codec, RecvError, SendError, UserError}; -use crate::frame::{self, Frame, Reason}; -use crate::proto::{peer, Peer, Open, WindowSize}; -use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; use super::recv::RecvHeaderBlockError; use super::store::{self, Entry, Resolve, Store}; +use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; +use crate::codec::{Codec, RecvError, SendError, UserError}; +use crate::frame::{self, Frame, Reason}; +use crate::proto::{peer, Open, Peer, WindowSize}; +use crate::{client, proto, server}; use bytes::{Buf, Bytes}; -use futures::{task, Async, Poll, try_ready}; +use futures::ready; use http::{HeaderMap, Request, Response}; +use std::task::{Context, Poll, Waker}; use tokio_io::AsyncWrite; -use std::{fmt, io}; +use crate::PollExt; use std::sync::{Arc, Mutex}; +use std::{fmt, io}; #[derive(Debug)] pub(crate) struct Streams @@ -77,7 +79,7 @@ struct Actions { send: Send, /// Task that calls `poll_complete`. - task: Option, + task: Option, /// If the connection errors, a copy is kept for any StreamRefs. conn_error: Option, @@ -93,7 +95,7 @@ struct SendBuffer { impl Streams where - B: Buf, + B: Buf + Unpin, P: Peer, { pub fn new(config: Config) -> Self { @@ -134,7 +136,11 @@ where // The GOAWAY process has begun. All streams with a greater ID than // specified as part of GOAWAY should be ignored. if id > me.actions.recv.max_stream_id() { - log::trace!("id ({:?}) > max_stream_id ({:?}), ignoring HEADERS", id, me.actions.recv.max_stream_id()); + log::trace!( + "id ({:?}) > max_stream_id ({:?}), ignoring HEADERS", + id, + me.actions.recv.max_stream_id() + ); return Ok(()); } @@ -170,10 +176,10 @@ where ); e.insert(stream) - }, + } None => return Ok(()), } - }, + } }; let stream = me.store.resolve(key); @@ -254,15 +260,16 @@ where // The GOAWAY process has begun. All streams with a greater ID // than specified as part of GOAWAY should be ignored. if id > me.actions.recv.max_stream_id() { - log::trace!("id ({:?}) > max_stream_id ({:?}), ignoring DATA", id, me.actions.recv.max_stream_id()); + log::trace!( + "id ({:?}) > max_stream_id ({:?}), ignoring DATA", + id, + me.actions.recv.max_stream_id() + ); return Ok(()); } if me.actions.may_have_forgotten_stream::

(id) { - log::debug!( - "recv_data for old stream={:?}, sending STREAM_CLOSED", - id, - ); + log::debug!("recv_data for old stream={:?}, sending STREAM_CLOSED", id,); let sz = frame.payload().len(); // This should have been enforced at the codec::FramedRead layer, so @@ -279,7 +286,7 @@ where proto_err!(conn: "recv_data: stream not found; id={:?}", id); return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); - }, + } }; let actions = &mut me.actions; @@ -294,7 +301,9 @@ where // we won't give the data to the user, and so they can't // release the capacity. We do it automatically. if let Err(RecvError::Stream { .. }) = res { - actions.recv.release_connection_capacity(sz as WindowSize, &mut None); + actions + .recv + .release_connection_capacity(sz as WindowSize, &mut None); } actions.reset_on_recv_stream_err(send_buffer, stream, counts, res) }) @@ -314,7 +323,11 @@ where // The GOAWAY process has begun. All streams with a greater ID than // specified as part of GOAWAY should be ignored. if id > me.actions.recv.max_stream_id() { - log::trace!("id ({:?}) > max_stream_id ({:?}), ignoring RST_STREAM", id, me.actions.recv.max_stream_id()); + log::trace!( + "id ({:?}) > max_stream_id ({:?}), ignoring RST_STREAM", + id, + me.actions.recv.max_stream_id() + ); return Ok(()); } @@ -327,7 +340,7 @@ where .map_err(RecvError::Connection)?; return Ok(()); - }, + } }; let mut send_buffer = self.send_buffer.inner.lock().unwrap(); @@ -400,14 +413,16 @@ where actions.recv.go_away(last_stream_id); me.store - .for_each(|stream| if stream.id > last_stream_id { - counts.transition(stream, |counts, stream| { - actions.recv.recv_err(&err, &mut *stream); - actions.send.recv_err(send_buffer, stream, counts); + .for_each(|stream| { + if stream.id > last_stream_id { + counts.transition(stream, |counts, stream| { + actions.recv.recv_err(&err, &mut *stream); + actions.send.recv_err(send_buffer, stream, counts); + Ok::<_, ()>(()) + }) + } else { Ok::<_, ()>(()) - }) - } else { - Ok::<_, ()>(()) + } }) .unwrap(); @@ -470,7 +485,11 @@ where // The GOAWAY process has begun. All streams with a greater ID // than specified as part of GOAWAY should be ignored. if id > me.actions.recv.max_stream_id() { - log::trace!("id ({:?}) > max_stream_id ({:?}), ignoring PUSH_PROMISE", id, me.actions.recv.max_stream_id()); + log::trace!( + "id ({:?}) > max_stream_id ({:?}), ignoring PUSH_PROMISE", + id, + me.actions.recv.max_stream_id() + ); return Ok(()); } @@ -480,8 +499,8 @@ where } None => { proto_err!(conn: "recv_push_promise: initiating stream is in an invalid state"); - return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)) - }, + return Err(RecvError::Connection(Reason::PROTOCOL_ERROR)); + } }; // TODO: Streams in the reserved states do not count towards the concurrency @@ -495,7 +514,12 @@ where // // If `None` is returned, then the stream is being refused. There is no // further work to be done. - if me.actions.recv.open(promised_id, Open::PushPromise, &mut me.counts)?.is_none() { + if me + .actions + .recv + .open(promised_id, Open::PushPromise, &mut me.counts)? + .is_none() + { return Ok(()); } @@ -507,21 +531,26 @@ where Stream::new( promised_id, me.actions.send.init_window_sz(), - me.actions.recv.init_window_sz()) + me.actions.recv.init_window_sz(), + ) }); let actions = &mut me.actions; me.counts.transition(stream, |counts, stream| { - let stream_valid = - actions.recv.recv_push_promise(frame, stream); + let stream_valid = actions.recv.recv_push_promise(frame, stream); match stream_valid { - Ok(()) => - Ok(Some(stream.key())), + Ok(()) => Ok(Some(stream.key())), _ => { let mut send_buffer = self.send_buffer.inner.lock().unwrap(); - actions.reset_on_recv_stream_err(&mut *send_buffer, stream, counts, stream_valid) + actions + .reset_on_recv_stream_err( + &mut *send_buffer, + stream, + counts, + stream_valid, + ) .map(|()| None) } } @@ -549,7 +578,11 @@ where me.refs += 1; key.map(|key| { let stream = &mut me.store.resolve(key); - log::trace!("next_incoming; id={:?}, state={:?}", stream.id, stream.state); + log::trace!( + "next_incoming; id={:?}, state={:?}", + stream.id, + stream.state + ); StreamRef { opaque: OpaqueStreamRef::new(self.inner.clone(), stream), send_buffer: self.send_buffer.clone(), @@ -559,25 +592,33 @@ where pub fn send_pending_refusal( &mut self, + cx: &mut Context, dst: &mut Codec>, - ) -> Poll<(), io::Error> + ) -> Poll> where - T: AsyncWrite, + T: AsyncWrite + Unpin, + B: Unpin, { let mut me = self.inner.lock().unwrap(); let me = &mut *me; - me.actions.recv.send_pending_refusal(dst) + me.actions.recv.send_pending_refusal(cx, dst) } pub fn clear_expired_reset_streams(&mut self) { let mut me = self.inner.lock().unwrap(); let me = &mut *me; - me.actions.recv.clear_expired_reset_streams(&mut me.store, &mut me.counts); + me.actions + .recv + .clear_expired_reset_streams(&mut me.store, &mut me.counts); } - pub fn poll_complete(&mut self, dst: &mut Codec>) -> Poll<(), io::Error> + pub fn poll_complete( + &mut self, + cx: &mut Context, + dst: &mut Codec>, + ) -> Poll> where - T: AsyncWrite, + T: AsyncWrite + Unpin, { let mut me = self.inner.lock().unwrap(); let me = &mut *me; @@ -589,20 +630,21 @@ where // // TODO: It would probably be better to interleave updates w/ data // frames. - try_ready!(me.actions.recv.poll_complete(&mut me.store, &mut me.counts, dst)); + ready!(me + .actions + .recv + .poll_complete(cx, &mut me.store, &mut me.counts, dst))?; // Send any other pending frames - try_ready!(me.actions.send.poll_complete( - send_buffer, - &mut me.store, - &mut me.counts, - dst - )); + ready!(me + .actions + .send + .poll_complete(cx, send_buffer, &mut me.store, &mut me.counts, dst))?; // Nothing else to do, track the task - me.actions.task = Some(task::current()); + me.actions.task = Some(cx.waker().clone()); - Ok(().into()) + Poll::Ready(Ok(())) } pub fn apply_remote_settings(&mut self, frame: &frame::Settings) -> Result<(), RecvError> { @@ -615,7 +657,12 @@ where me.counts.apply_remote_settings(frame); me.actions.send.apply_remote_settings( - frame, send_buffer, &mut me.store, &mut me.counts, &mut me.actions.task) + frame, + send_buffer, + &mut me.store, + &mut me.counts, + &mut me.actions.task, + ) } pub fn send_request( @@ -624,8 +671,8 @@ where end_of_stream: bool, pending: Option<&OpaqueStreamRef>, ) -> Result, SendError> { - use http::Method; use super::stream::ContentLength; + use http::Method; // TODO: There is a hazard with assigning a stream ID before the // prioritize layer. If prioritization reorders new streams, this @@ -671,8 +718,7 @@ where } // Convert the message - let headers = client::Peer::convert_send_message( - stream_id, request, end_of_stream)?; + let headers = client::Peer::convert_send_message(stream_id, request, end_of_stream)?; let mut stream = me.store.insert(stream.id, stream); @@ -701,10 +747,7 @@ where me.refs += 1; Ok(StreamRef { - opaque: OpaqueStreamRef::new( - self.inner.clone(), - &mut stream, - ), + opaque: OpaqueStreamRef::new(self.inner.clone(), &mut stream), send_buffer: self.send_buffer.clone(), }) } @@ -719,13 +762,14 @@ where let stream = Stream::new(id, 0, 0); e.insert(stream) - }, + } }; let stream = me.store.resolve(key); let mut send_buffer = self.send_buffer.inner.lock().unwrap(); let send_buffer = &mut *send_buffer; - me.actions.send_reset(stream, reason, &mut me.counts, send_buffer); + me.actions + .send_reset(stream, reason, &mut me.counts, send_buffer); } pub fn send_go_away(&mut self, last_processed_id: StreamId) { @@ -740,7 +784,11 @@ impl Streams where B: Buf, { - pub fn poll_pending_open(&mut self, pending: Option<&OpaqueStreamRef>) -> Poll<(), crate::Error> { + pub fn poll_pending_open( + &mut self, + cx: &Context, + pending: Option<&OpaqueStreamRef>, + ) -> Poll> { let mut me = self.inner.lock().unwrap(); let me = &mut *me; @@ -751,11 +799,11 @@ where let mut stream = me.store.resolve(pending.key); log::trace!("poll_pending_open; stream = {:?}", stream.is_pending_open); if stream.is_pending_open { - stream.wait_send(); - return Ok(Async::NotReady); + stream.wait_send(cx); + return Poll::Pending; } } - Ok(().into()) + Poll::Ready(Ok(())) } } @@ -845,7 +893,6 @@ where } } - // ===== impl StreamRef ===== impl StreamRef { @@ -867,12 +914,9 @@ impl StreamRef { frame.set_end_stream(end_stream); // Send the data frame - actions.send.send_data( - frame, - send_buffer, - stream, - counts, - &mut actions.task) + actions + .send + .send_data(frame, send_buffer, stream, counts, &mut actions.task) }) } @@ -890,8 +934,9 @@ impl StreamRef { let frame = frame::Headers::trailers(stream.id, trailers); // Send the trailers frame - actions.send.send_trailers( - frame, send_buffer, stream, counts, &mut actions.task) + actions + .send + .send_trailers(frame, send_buffer, stream, counts, &mut actions.task) }) } @@ -903,7 +948,8 @@ impl StreamRef { let mut send_buffer = self.send_buffer.inner.lock().unwrap(); let send_buffer = &mut *send_buffer; - me.actions.send_reset(stream, reason, &mut me.counts, send_buffer); + me.actions + .send_reset(stream, reason, &mut me.counts, send_buffer); } pub fn send_response( @@ -922,8 +968,9 @@ impl StreamRef { me.counts.transition(stream, |counts, stream| { let frame = server::Peer::convert_send_message(stream.id, response, end_of_stream); - actions.send.send_headers( - frame, send_buffer, stream, counts, &mut actions.task) + actions + .send + .send_headers(frame, send_buffer, stream, counts, &mut actions.task) }) } @@ -955,7 +1002,9 @@ impl StreamRef { let mut stream = me.store.resolve(self.opaque.key); - me.actions.send.reserve_capacity(capacity, &mut stream, &mut me.counts) + me.actions + .send + .reserve_capacity(capacity, &mut stream, &mut me.counts) } /// Returns the stream's current send capacity. @@ -969,28 +1018,35 @@ impl StreamRef { } /// Request to be notified when the stream's capacity increases - pub fn poll_capacity(&mut self) -> Poll, UserError> { + pub fn poll_capacity(&mut self, cx: &Context) -> Poll>> { let mut me = self.opaque.inner.lock().unwrap(); let me = &mut *me; let mut stream = me.store.resolve(self.opaque.key); - me.actions.send.poll_capacity(&mut stream) + me.actions.send.poll_capacity(cx, &mut stream) } /// Request to be notified for if a `RST_STREAM` is received for this stream. - pub(crate) fn poll_reset(&mut self, mode: proto::PollReset) -> Poll { + pub(crate) fn poll_reset( + &mut self, + cx: &Context, + mode: proto::PollReset, + ) -> Poll> { let mut me = self.opaque.inner.lock().unwrap(); let me = &mut *me; let mut stream = me.store.resolve(self.opaque.key); - me.actions.send.poll_reset(&mut stream, mode) + me.actions + .send + .poll_reset(cx, &mut stream, mode) .map_err(From::from) } pub fn clone_to_opaque(&self) -> OpaqueStreamRef - where B: 'static, + where + B: 'static, { self.opaque.clone() } @@ -1015,35 +1071,37 @@ impl OpaqueStreamRef { fn new(inner: Arc>, stream: &mut store::Ptr) -> OpaqueStreamRef { stream.ref_inc(); OpaqueStreamRef { - inner, key: stream.key() + inner, + key: stream.key(), } } /// Called by a client to check for a received response. - pub fn poll_response(&mut self) -> Poll, proto::Error> { + pub fn poll_response(&mut self, cx: &Context) -> Poll, proto::Error>> { let mut me = self.inner.lock().unwrap(); let me = &mut *me; let mut stream = me.store.resolve(self.key); - me.actions.recv.poll_response(&mut stream) + me.actions.recv.poll_response(cx, &mut stream) } /// Called by a client to check for a pushed request. pub fn poll_pushed( - &mut self - ) -> Poll, OpaqueStreamRef)>, proto::Error> { + &mut self, + cx: &Context, + ) -> Poll, OpaqueStreamRef), proto::Error>>> { let mut me = self.inner.lock().unwrap(); let me = &mut *me; - let res = { - let mut stream = me.store.resolve(self.key); - try_ready!(me.actions.recv.poll_pushed(&mut stream)) - }; - Ok(Async::Ready(res.map(|(h, key)| { - me.refs += 1; - let opaque_ref = - OpaqueStreamRef::new(self.inner.clone(), &mut me.store.resolve(key)); - (h, opaque_ref) - }))) + let mut stream = me.store.resolve(self.key); + me.actions + .recv + .poll_pushed(cx, &mut stream) + .map_ok_(|(h, key)| { + me.refs += 1; + let opaque_ref = + OpaqueStreamRef::new(self.inner.clone(), &mut me.store.resolve(key)); + (h, opaque_ref) + }) } pub fn body_is_empty(&self) -> bool { @@ -1064,22 +1122,22 @@ impl OpaqueStreamRef { me.actions.recv.is_end_stream(&stream) } - pub fn poll_data(&mut self) -> Poll, proto::Error> { + pub fn poll_data(&mut self, cx: &Context) -> Poll>> { let mut me = self.inner.lock().unwrap(); let me = &mut *me; let mut stream = me.store.resolve(self.key); - me.actions.recv.poll_data(&mut stream) + me.actions.recv.poll_data(cx, &mut stream) } - pub fn poll_trailers(&mut self) -> Poll, proto::Error> { + pub fn poll_trailers(&mut self, cx: &Context) -> Poll>> { let mut me = self.inner.lock().unwrap(); let me = &mut *me; let mut stream = me.store.resolve(self.key); - me.actions.recv.poll_trailers(&mut stream) + me.actions.recv.poll_trailers(cx, &mut stream) } /// Releases recv capacity back to the peer. This may result in sending @@ -1101,16 +1159,11 @@ impl OpaqueStreamRef { let mut stream = me.store.resolve(self.key); - me.actions - .recv - .clear_recv_buffer(&mut stream); + me.actions.recv.clear_recv_buffer(&mut stream); } pub fn stream_id(&self) -> StreamId { - self.inner.lock() - .unwrap() - .store[self.key] - .id + self.inner.lock().unwrap().store[self.key].id } } @@ -1125,17 +1178,15 @@ impl fmt::Debug for OpaqueStreamRef { .field("stream_id", &stream.id) .field("ref_count", &stream.ref_count) .finish() - }, - Err(Poisoned(_)) => { - fmt.debug_struct("OpaqueStreamRef") - .field("inner", &"") - .finish() - } - Err(WouldBlock) => { - fmt.debug_struct("OpaqueStreamRef") - .field("inner", &"") - .finish() } + Err(Poisoned(_)) => fmt + .debug_struct("OpaqueStreamRef") + .field("inner", &"") + .finish(), + Err(WouldBlock) => fmt + .debug_struct("OpaqueStreamRef") + .field("inner", &"") + .finish(), } } } @@ -1164,12 +1215,14 @@ impl Drop for OpaqueStreamRef { fn drop_stream_ref(inner: &Mutex, key: store::Key) { let mut me = match inner.lock() { Ok(inner) => inner, - Err(_) => if ::std::thread::panicking() { - log::trace!("StreamRef::drop; mutex poisoned"); - return; - } else { - panic!("StreamRef::drop; mutex poisoned"); - }, + Err(_) => { + if ::std::thread::panicking() { + log::trace!("StreamRef::drop; mutex poisoned"); + return; + } else { + panic!("StreamRef::drop; mutex poisoned"); + } + } }; let me = &mut *me; @@ -1189,19 +1242,19 @@ fn drop_stream_ref(inner: &Mutex, key: store::Key) { // (connection) so that it can close properly if stream.ref_count == 0 && stream.is_closed() { if let Some(task) = actions.task.take() { - task.notify(); + task.wake(); } } - me.counts.transition(stream, |counts, stream| { maybe_cancel(stream, actions, counts); if stream.ref_count == 0 { - // Release any recv window back to connection, no one can access // it anymore. - actions.recv.release_closed_capacity(stream, &mut actions.task); + actions + .recv + .release_closed_capacity(stream, &mut actions.task); // We won't be able to reach our push promises anymore let mut ppp = stream.pending_push_promises.take(); @@ -1216,11 +1269,9 @@ fn drop_stream_ref(inner: &Mutex, key: store::Key) { fn maybe_cancel(stream: &mut store::Ptr, actions: &mut Actions, counts: &mut Counts) { if stream.is_canceled_interest() { - actions.send.schedule_implicit_reset( - stream, - Reason::CANCEL, - counts, - &mut actions.task); + actions + .send + .schedule_implicit_reset(stream, Reason::CANCEL, counts, &mut actions.task); actions.recv.enqueue_reset_expiration(stream, counts); } } @@ -1245,8 +1296,8 @@ impl Actions { send_buffer: &mut Buffer>, ) { counts.transition(stream, |counts, stream| { - self.send.send_reset( - reason, send_buffer, stream, counts, &mut self.task); + self.send + .send_reset(reason, send_buffer, stream, counts, &mut self.task); self.recv.enqueue_reset_expiration(stream, counts); // if a RecvStream is parked, ensure it's notified stream.notify_recv(); @@ -1260,12 +1311,10 @@ impl Actions { counts: &mut Counts, res: Result<(), RecvError>, ) -> Result<(), RecvError> { - if let Err(RecvError::Stream { - reason, .. - }) = res - { + if let Err(RecvError::Stream { reason, .. }) = res { // Reset the stream. - self.send.send_reset(reason, buffer, stream, counts, &mut self.task); + self.send + .send_reset(reason, buffer, stream, counts, &mut self.task); Ok(()) } else { res @@ -1308,11 +1357,7 @@ impl Actions { } } - fn clear_queues(&mut self, - clear_pending_accept: bool, - store: &mut Store, - counts: &mut Counts) - { + fn clear_queues(&mut self, clear_pending_accept: bool, store: &mut Store, counts: &mut Counts) { self.recv.clear_queues(clear_pending_accept, store, counts); self.send.clear_queues(store, counts); } diff --git a/src/server.rs b/src/server.rs index 3f6d8255a..dc9d257e3 100644 --- a/src/server.rs +++ b/src/server.rs @@ -64,50 +64,45 @@ //! will use the HTTP/2.0 protocol without prior negotiation. //! //! ```rust -//! use futures::{Future, Stream}; -//! # use futures::future::ok; +//! #![feature(async_await)] +//! use futures::StreamExt; //! use h2::server; //! use http::{Response, StatusCode}; //! use tokio::net::TcpListener; //! -//! pub fn main () { +//! #[tokio::main] +//! pub async fn main () { //! let addr = "127.0.0.1:5928".parse().unwrap(); //! let listener = TcpListener::bind(&addr,).unwrap(); //! -//! tokio::run({ -//! // Accept all incoming TCP connections. -//! listener.incoming().for_each(move |socket| { -//! // Spawn a new task to process each connection. -//! tokio::spawn({ -//! // Start the HTTP/2.0 connection handshake -//! server::handshake(socket) -//! .and_then(|h2| { -//! // Accept all inbound HTTP/2.0 streams sent over the -//! // connection. -//! h2.for_each(|(request, mut respond)| { -//! println!("Received request: {:?}", request); +//! // Accept all incoming TCP connections. +//! let mut incoming = listener.incoming(); +//! # futures::future::select(Box::pin(async { +//! while let Some(socket) = incoming.next().await { +//! // Spawn a new task to process each connection. +//! tokio::spawn(async { +//! // Start the HTTP/2.0 connection handshake +//! let mut h2 = server::handshake(socket.unwrap()).await.unwrap(); +//! // Accept all inbound HTTP/2.0 streams sent over the +//! // connection. +//! while let Some(request) = h2.next().await { +//! let (request, mut respond) = request.unwrap(); +//! println!("Received request: {:?}", request); //! -//! // Build a response with no body -//! let response = Response::builder() -//! .status(StatusCode::OK) -//! .body(()) -//! .unwrap(); +//! // Build a response with no body +//! let response = Response::builder() +//! .status(StatusCode::OK) +//! .body(()) +//! .unwrap(); //! -//! // Send the response back to the client -//! respond.send_response(response, true) -//! .unwrap(); +//! // Send the response back to the client +//! respond.send_response(response, true) +//! .unwrap(); +//! } //! -//! Ok(()) -//! }) -//! }) -//! .map_err(|e| panic!("unexpected error = {:?}", e)) -//! }); -//! -//! Ok(()) -//! }) -//! .map_err(|e| panic!("failed to run HTTP/2.0 server: {:?}", e)) -//! # .select(ok(())).map(|_|()).map_err(|_|()) -//! }); +//! }); +//! } +//! # }), Box::pin(async {})).await; //! } //! ``` //! @@ -124,17 +119,20 @@ //! [`SendStream`]: ../struct.SendStream.html //! [`TcpListener`]: https://docs.rs/tokio-core/0.1/tokio_core/net/struct.TcpListener.html -use crate::{SendStream, RecvStream, ReleaseCapacity, PingPong}; use crate::codec::{Codec, RecvError}; use crate::frame::{self, Pseudo, Reason, Settings, StreamId}; use crate::proto::{self, Config, Prioritized}; +use crate::{PingPong, RecvStream, ReleaseCapacity, SendStream}; use bytes::{Buf, Bytes, IntoBuf}; -use futures::{self, Async, Future, Poll, try_ready}; +use futures::ready; use http::{HeaderMap, Request, Response}; -use std::{convert, fmt, io, mem}; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Duration; -use tokio_io::{AsyncRead, AsyncWrite, try_nb}; +use std::{convert, fmt, io, mem}; +use tokio_io::{AsyncRead, AsyncWrite}; /// In progress HTTP/2.0 connection handshake future. /// @@ -155,7 +153,7 @@ pub struct Handshake { /// The config to pass to Connection::new after handshake succeeds. builder: Builder, /// The current state of the handshake. - state: Handshaking + state: Handshaking, } /// Accepts inbound HTTP/2.0 streams on a connection. @@ -179,21 +177,19 @@ pub struct Handshake { /// # Examples /// /// ``` -/// # use futures::{Future, Stream}; +/// # #![feature(async_await)] +/// # use futures::StreamExt; /// # use tokio_io::*; /// # use h2::server; /// # use h2::server::*; /// # -/// # fn doc(my_io: T) { -/// server::handshake(my_io) -/// .and_then(|server| { -/// server.for_each(|(request, respond)| { -/// // Process the request and send the response back to the client -/// // using `respond`. -/// # Ok(()) -/// }) -/// }) -/// # .wait().unwrap(); +/// # async fn doc(my_io: T) { +/// let mut server = server::handshake(my_io).await.unwrap(); +/// while let Some(request) = server.next().await { +/// let (request, respond) = request.unwrap(); +/// // Process the request and send the response back to the client +/// // using `respond`. +/// } /// # } /// # /// # pub fn main() {} @@ -224,7 +220,7 @@ pub struct Connection { /// # use tokio_io::*; /// # use h2::server::*; /// # -/// # fn doc(my_io: T) +/// # fn doc(my_io: T) /// # -> Handshake /// # { /// // `server_fut` is a future representing the completion of the HTTP/2.0 @@ -318,26 +314,23 @@ const PREFACE: [u8; 24] = *b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; /// # Examples /// /// ``` +/// # #![feature(async_await)] /// # use tokio_io::*; -/// # use futures::*; /// # use h2::server; /// # use h2::server::*; /// # -/// # fn doc(my_io: T) +/// # async fn doc(my_io: T) /// # { -/// server::handshake(my_io) -/// .and_then(|connection| { -/// // The HTTP/2.0 handshake has completed, now use `connection` to -/// // accept inbound HTTP/2.0 streams. -/// # Ok(()) -/// }) -/// # .wait().unwrap(); +/// let connection = server::handshake(my_io).await.unwrap(); +/// // The HTTP/2.0 handshake has completed, now use `connection` to +/// // accept inbound HTTP/2.0 streams. /// # } /// # /// # pub fn main() {} /// ``` pub fn handshake(io: T) -> Handshake -where T: AsyncRead + AsyncWrite, +where + T: AsyncRead + AsyncWrite + Unpin, { Builder::new().handshake(io) } @@ -346,8 +339,9 @@ where T: AsyncRead + AsyncWrite, impl Connection where - T: AsyncRead + AsyncWrite, - B: IntoBuf, + T: AsyncRead + AsyncWrite + Unpin, + B: IntoBuf + Unpin, + B::Buf: Unpin, { fn handshake2(io: T, builder: Builder) -> Handshake { // Create the codec. @@ -407,11 +401,14 @@ where /// [`poll`]: struct.Connection.html#method.poll /// [`RecvStream`]: ../struct.RecvStream.html /// [`SendStream`]: ../struct.SendStream.html - pub fn poll_close(&mut self) -> Poll<(), crate::Error> { - self.connection.poll().map_err(Into::into) + pub fn poll_close(&mut self, cx: &mut Context) -> Poll> { + self.connection.poll(cx).map_err(Into::into) } - #[deprecated(note="use abrupt_shutdown or graceful_shutdown instead", since="0.1.4")] + #[deprecated( + note = "use abrupt_shutdown or graceful_shutdown instead", + since = "0.1.4" + )] #[doc(hidden)] pub fn close_connection(&mut self) { self.graceful_shutdown(); @@ -453,31 +450,28 @@ where /// /// This may only be called once. Calling multiple times will return `None`. pub fn ping_pong(&mut self) -> Option { - self.connection - .take_user_pings() - .map(PingPong::new) + self.connection.take_user_pings().map(PingPong::new) } } impl futures::Stream for Connection where - T: AsyncRead + AsyncWrite, - B: IntoBuf, - B::Buf: 'static, + T: AsyncRead + AsyncWrite + Unpin, + B: IntoBuf + Unpin, + B::Buf: Unpin + 'static, { - type Item = (Request, SendResponse); - type Error = crate::Error; - - fn poll(&mut self) -> Poll, crate::Error> { - // Always try to advance the internal state. Getting NotReady also is - // needed to allow this function to return NotReady. - match self.poll_close()? { - Async::Ready(_) => { + type Item = Result<(Request, SendResponse), crate::Error>; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // Always try to advance the internal state. Getting Pending also is + // needed to allow this function to return Pending. + match self.poll_close(cx)? { + Poll::Ready(_) => { // If the socket is closed, don't return anything // TODO: drop any pending streams - return Ok(None.into()); - }, - _ => {}, + return Poll::Ready(None); + } + _ => {} } if let Some(inner) = self.connection.next_incoming() { @@ -488,10 +482,10 @@ where let request = Request::from_parts(head, body); let respond = SendResponse { inner }; - return Ok(Some((request, respond)).into()); + return Poll::Ready(Some(Ok((request, respond)))); } - Ok(Async::NotReady) + Poll::Pending } } @@ -522,7 +516,7 @@ impl Builder { /// # use tokio_io::*; /// # use h2::server::*; /// # - /// # fn doc(my_io: T) + /// # fn doc(my_io: T) /// # -> Handshake /// # { /// // `server_fut` is a future representing the completion of the HTTP/2.0 @@ -561,7 +555,7 @@ impl Builder { /// # use tokio_io::*; /// # use h2::server::*; /// # - /// # fn doc(my_io: T) + /// # fn doc(my_io: T) /// # -> Handshake /// # { /// // `server_fut` is a future representing the completion of the HTTP/2.0 @@ -595,7 +589,7 @@ impl Builder { /// # use tokio_io::*; /// # use h2::server::*; /// # - /// # fn doc(my_io: T) + /// # fn doc(my_io: T) /// # -> Handshake /// # { /// // `server_fut` is a future representing the completion of the HTTP/2.0 @@ -628,7 +622,7 @@ impl Builder { /// # use tokio_io::*; /// # use h2::server::*; /// # - /// # fn doc(my_io: T) + /// # fn doc(my_io: T) /// # -> Handshake /// # { /// // `server_fut` is a future representing the completion of the HTTP/2.0 @@ -667,7 +661,7 @@ impl Builder { /// # use tokio_io::*; /// # use h2::server::*; /// # - /// # fn doc(my_io: T) + /// # fn doc(my_io: T) /// # -> Handshake /// # { /// // `server_fut` is a future representing the completion of the HTTP/2.0 @@ -715,7 +709,7 @@ impl Builder { /// # use tokio_io::*; /// # use h2::server::*; /// # - /// # fn doc(my_io: T) + /// # fn doc(my_io: T) /// # -> Handshake /// # { /// // `server_fut` is a future representing the completion of the HTTP/2.0 @@ -761,7 +755,7 @@ impl Builder { /// # use tokio_io::*; /// # use h2::server::*; /// # - /// # fn doc(my_io: T) + /// # fn doc(my_io: T) /// # -> Handshake /// # { /// // `server_fut` is a future representing the completion of the HTTP/2.0 @@ -808,7 +802,7 @@ impl Builder { /// # use h2::server::*; /// # use std::time::Duration; /// # - /// # fn doc(my_io: T) + /// # fn doc(my_io: T) /// # -> Handshake /// # { /// // `server_fut` is a future representing the completion of the HTTP/2.0 @@ -850,7 +844,7 @@ impl Builder { /// # use tokio_io::*; /// # use h2::server::*; /// # - /// # fn doc(my_io: T) + /// # fn doc(my_io: T) /// # -> Handshake /// # { /// // `server_fut` is a future representing the completion of the HTTP/2.0 @@ -870,7 +864,7 @@ impl Builder { /// # use tokio_io::*; /// # use h2::server::*; /// # - /// # fn doc(my_io: T) + /// # fn doc(my_io: T) /// # -> Handshake /// # { /// // `server_fut` is a future representing the completion of the HTTP/2.0 @@ -884,9 +878,9 @@ impl Builder { /// ``` pub fn handshake(&self, io: T) -> Handshake where - T: AsyncRead + AsyncWrite, - B: IntoBuf, - B::Buf: 'static, + T: AsyncRead + AsyncWrite + Unpin, + B: IntoBuf + Unpin, + B::Buf: Unpin + 'static, { Connection::handshake2(io, self.clone()) } @@ -949,7 +943,7 @@ impl SendResponse { /// Polls to be notified when the client resets this stream. /// - /// If stream is still open, this returns `Ok(Async::NotReady)`, and + /// If stream is still open, this returns `Poll::Pending`, and /// registers the task to be notified if a `RST_STREAM` is received. /// /// If a `RST_STREAM` frame is received for this stream, calling this @@ -959,8 +953,8 @@ impl SendResponse { /// /// Calling this method after having called `send_response` will return /// a user error. - pub fn poll_reset(&mut self) -> Poll { - self.inner.poll_reset(proto::PollReset::AwaitingHeaders) + pub fn poll_reset(&mut self, cx: &mut Context) -> Poll> { + self.inner.poll_reset(cx, proto::PollReset::AwaitingHeaders) } /// Returns the stream ID of the response stream. @@ -979,26 +973,23 @@ impl SendResponse { impl Flush { fn new(codec: Codec) -> Self { - Flush { - codec: Some(codec), - } + Flush { codec: Some(codec) } } } impl Future for Flush where - T: AsyncWrite, - B: Buf, + T: AsyncWrite + Unpin, + B: Buf + Unpin, { - type Item = Codec; - type Error = crate::Error; + type Output = Result, crate::Error>; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { // Flush the codec - try_ready!(self.codec.as_mut().unwrap().flush()); + ready!(self.codec.as_mut().unwrap().flush(cx))?; // Return the codec - Ok(Async::Ready(self.codec.take().unwrap())) + Poll::Ready(Ok(self.codec.take().unwrap())) } } @@ -1017,49 +1008,50 @@ impl ReadPreface { impl Future for ReadPreface where - T: AsyncRead, - B: Buf, + T: AsyncRead + Unpin, + B: Buf + Unpin, { - type Item = Codec; - type Error = crate::Error; + type Output = Result, crate::Error>; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut buf = [0; 24]; let mut rem = PREFACE.len() - self.pos; while rem > 0 { - let n = try_nb!(self.inner_mut().read(&mut buf[..rem])); + let n = ready!(Pin::new(self.inner_mut()).poll_read(cx, &mut buf[..rem]))?; if n == 0 { - return Err(io::Error::new( + return Poll::Ready(Err(io::Error::new( io::ErrorKind::ConnectionReset, "connection closed unexpectedly", - ).into()); + ) + .into())); } if PREFACE[self.pos..self.pos + n] != buf[..n] { proto_err!(conn: "read_preface: invalid preface"); // TODO: Should this just write the GO_AWAY frame directly? - return Err(Reason::PROTOCOL_ERROR.into()); + return Poll::Ready(Err(Reason::PROTOCOL_ERROR.into())); } self.pos += n; rem -= n; // TODO test } - Ok(Async::Ready(self.codec.take().unwrap())) + Poll::Ready(Ok(self.codec.take().unwrap())) } } // ===== impl Handshake ===== impl Future for Handshake - where T: AsyncRead + AsyncWrite, - B: IntoBuf, +where + T: AsyncRead + AsyncWrite + Unpin, + B: IntoBuf + Unpin, + B::Buf: Unpin, { - type Item = Connection; - type Error = crate::Error; + type Output = Result, crate::Error>; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { log::trace!("Handshake::poll(); state={:?};", self.state); use crate::server::Handshaking::*; @@ -1067,12 +1059,12 @@ impl Future for Handshake // We're currently flushing a pending SETTINGS frame. Poll the // flush future, and, if it's completed, advance our state to wait // for the client preface. - let codec = match flush.poll()? { - Async::NotReady => { - log::trace!("Handshake::poll(); flush.poll()=NotReady"); - return Ok(Async::NotReady); - }, - Async::Ready(flushed) => { + let codec = match Pin::new(flush).poll(cx)? { + Poll::Pending => { + log::trace!("Handshake::poll(); flush.poll()=Pending"); + return Poll::Pending; + } + Poll::Ready(flushed) => { log::trace!("Handshake::poll(); flush.poll()=Ready"); flushed } @@ -1089,38 +1081,41 @@ impl Future for Handshake // We're now waiting for the client preface. Poll the `ReadPreface` // future. If it has completed, we will create a `Connection` handle // for the connection. - read.poll() - // Actually creating the `Connection` has to occur outside of this - // `if let` block, because we've borrowed `self` mutably in order - // to poll the state and won't be able to borrow the SETTINGS frame - // as well until we release the borrow for `poll()`. + Pin::new(read).poll(cx) + // Actually creating the `Connection` has to occur outside of this + // `if let` block, because we've borrowed `self` mutably in order + // to poll the state and won't be able to borrow the SETTINGS frame + // as well until we release the borrow for `poll()`. } else { unreachable!("Handshake::poll() state was not advanced completely!") }; - let server = poll?.map(|codec| { - let connection = proto::Connection::new(codec, Config { - next_stream_id: 2.into(), - // Server does not need to locally initiate any streams - initial_max_send_streams: 0, - reset_stream_duration: self.builder.reset_stream_duration, - reset_stream_max: self.builder.reset_stream_max, - settings: self.builder.settings.clone(), - }); + poll?.map(|codec| { + let connection = proto::Connection::new( + codec, + Config { + next_stream_id: 2.into(), + // Server does not need to locally initiate any streams + initial_max_send_streams: 0, + reset_stream_duration: self.builder.reset_stream_duration, + reset_stream_max: self.builder.reset_stream_max, + settings: self.builder.settings.clone(), + }, + ); log::trace!("Handshake::poll(); connection established!"); let mut c = Connection { connection }; if let Some(sz) = self.builder.initial_target_connection_window_size { c.set_target_window_size(sz); } - c - }); - Ok(server) + Ok(c) + }) } } impl fmt::Debug for Handshake - where T: AsyncRead + AsyncWrite + fmt::Debug, - B: fmt::Debug + IntoBuf, +where + T: AsyncRead + AsyncWrite + fmt::Debug, + B: fmt::Debug + IntoBuf, { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { write!(fmt, "server::Handshake") @@ -1131,16 +1126,14 @@ impl Peer { pub fn convert_send_message( id: StreamId, response: Response<()>, - end_of_stream: bool) -> frame::Headers - { + end_of_stream: bool, + ) -> frame::Headers { use http::response::Parts; // Extract the components of the HTTP request let ( Parts { - status, - headers, - .. + status, headers, .. }, _, ) = response.into_parts(); @@ -1172,7 +1165,9 @@ impl proto::Peer for Peer { } fn convert_poll_message( - pseudo: Pseudo, fields: HeaderMap, stream_id: StreamId + pseudo: Pseudo, + fields: HeaderMap, + stream_id: StreamId, ) -> Result { use http::{uri, Version}; @@ -1205,23 +1200,29 @@ impl proto::Peer for Peer { // Convert the URI let mut parts = uri::Parts::default(); - // A request translated from HTTP/1 must not include the :authority // header if let Some(authority) = pseudo.authority { let maybe_authority = uri::Authority::from_shared(authority.clone().into_inner()); - parts.authority = Some(maybe_authority.or_else(|why| malformed!( - "malformed headers: malformed authority ({:?}): {}", authority, why, - ))?); - + parts.authority = Some(maybe_authority.or_else(|why| { + malformed!( + "malformed headers: malformed authority ({:?}): {}", + authority, + why, + ) + })?); } // A :scheme is always required. if let Some(scheme) = pseudo.scheme { let maybe_scheme = uri::Scheme::from_shared(scheme.clone().into_inner()); - let scheme = maybe_scheme.or_else(|why| malformed!( - "malformed headers: malformed scheme ({:?}): {}", scheme, why, - ))?; + let scheme = maybe_scheme.or_else(|why| { + malformed!( + "malformed headers: malformed scheme ({:?}): {}", + scheme, + why, + ) + })?; // It's not possible to build an `Uri` from a scheme and path. So, // after validating is was a valid scheme, we just have to drop it @@ -1240,9 +1241,9 @@ impl proto::Peer for Peer { } let maybe_path = uri::PathAndQuery::from_shared(path.clone().into_inner()); - parts.path_and_query = Some(maybe_path.or_else(|why| malformed!( - "malformed headers: malformed path ({:?}): {}", path, why, - ))?); + parts.path_and_query = Some(maybe_path.or_else(|why| { + malformed!("malformed headers: malformed path ({:?}): {}", path, why,) + })?); } b.uri(parts); @@ -1257,7 +1258,7 @@ impl proto::Peer for Peer { id: stream_id, reason: Reason::PROTOCOL_ERROR, }); - }, + } }; *request.headers_mut() = fields; @@ -1270,18 +1271,15 @@ impl proto::Peer for Peer { impl fmt::Debug for Handshaking where - B: IntoBuf + B: IntoBuf, { - #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { match *self { - Handshaking::Flushing(_) => - write!(f, "Handshaking::Flushing(_)"), - Handshaking::ReadingPreface(_) => - write!(f, "Handshaking::ReadingPreface(_)"), - Handshaking::Empty => - write!(f, "Handshaking::Empty"), + Handshaking::Flushing(_) => write!(f, "Handshaking::Flushing(_)"), + Handshaking::ReadingPreface(_) => write!(f, "Handshaking::ReadingPreface(_)"), + Handshaking::Empty => write!(f, "Handshaking::Empty"), } - } } @@ -1290,18 +1288,19 @@ where T: AsyncRead + AsyncWrite, B: IntoBuf, { - #[inline] fn from(flush: Flush>) -> Self { + #[inline] + fn from(flush: Flush>) -> Self { Handshaking::Flushing(flush) } } -impl convert::From>> for - Handshaking +impl convert::From>> for Handshaking where T: AsyncRead + AsyncWrite, B: IntoBuf, { - #[inline] fn from(read: ReadPreface>) -> Self { + #[inline] + fn from(read: ReadPreface>) -> Self { Handshaking::ReadingPreface(read) } } @@ -1311,7 +1310,8 @@ where T: AsyncRead + AsyncWrite, B: IntoBuf, { - #[inline] fn from(codec: Codec>) -> Self { + #[inline] + fn from(codec: Codec>) -> Self { Handshaking::from(Flush::new(codec)) } } diff --git a/src/share.rs b/src/share.rs index 2d6df0716..4f6bc397d 100644 --- a/src/share.rs +++ b/src/share.rs @@ -3,10 +3,13 @@ use crate::frame::Reason; use crate::proto::{self, WindowSize}; use bytes::{Bytes, IntoBuf}; -use futures::{self, Poll, Async, try_ready}; -use http::{HeaderMap}; +use http::HeaderMap; +use crate::PollExt; +use futures::ready; use std::fmt; +use std::pin::Pin; +use std::task::{Context, Poll}; /// Sends the body stream and trailers to the remote peer. /// @@ -264,11 +267,12 @@ impl SendStream { /// is sent. For example: /// /// ```rust + /// # #![feature(async_await)] /// # use h2::*; - /// # fn doc(mut send_stream: SendStream<&'static [u8]>) { + /// # async fn doc(mut send_stream: SendStream<&'static [u8]>) { /// send_stream.reserve_capacity(100); /// - /// let capacity = send_stream.poll_capacity(); + /// let capacity = futures::future::poll_fn(|cx| send_stream.poll_capacity(cx)).await; /// // capacity == 5; /// /// send_stream.send_data(b"hello", false).unwrap(); @@ -309,9 +313,11 @@ impl SendStream { /// amount of assigned capacity at that point in time. It is also possible /// that `n` is lower than the previous call if, since then, the caller has /// sent data. - pub fn poll_capacity(&mut self) -> Poll, crate::Error> { - let res = try_ready!(self.inner.poll_capacity()); - Ok(Async::Ready(res.map(|v| v as usize))) + pub fn poll_capacity(&mut self, cx: &mut Context) -> Poll>> { + self.inner + .poll_capacity(cx) + .map_ok_(|w| w as usize) + .map_err_(Into::into) } /// Sends a single data frame to the remote peer. @@ -356,7 +362,7 @@ impl SendStream { /// Polls to be notified when the client resets this stream. /// - /// If stream is still open, this returns `Ok(Async::NotReady)`, and + /// If stream is still open, this returns `Poll::Pending`, and /// registers the task to be notified if a `RST_STREAM` is received. /// /// If a `RST_STREAM` frame is received for this stream, calling this @@ -366,8 +372,8 @@ impl SendStream { /// /// If connection sees an error, this returns that error instead of a /// `Reason`. - pub fn poll_reset(&mut self) -> Poll { - self.inner.poll_reset(proto::PollReset::Streaming) + pub fn poll_reset(&mut self, cx: &mut Context) -> Poll> { + self.inner.poll_reset(cx, proto::PollReset::Streaming) } /// Returns the stream ID of this `SendStream`. @@ -417,8 +423,11 @@ impl RecvStream { } /// Returns received trailers. - pub fn poll_trailers(&mut self) -> Poll, crate::Error> { - self.inner.inner.poll_trailers().map_err(Into::into) + pub fn poll_trailers( + &mut self, + cx: &mut Context, + ) -> Poll>> { + self.inner.inner.poll_trailers(cx).map_err_(Into::into) } /// Returns the stream ID of this stream. @@ -432,11 +441,10 @@ impl RecvStream { } impl futures::Stream for RecvStream { - type Item = Bytes; - type Error = crate::Error; + type Item = Result; - fn poll(&mut self) -> Poll, Self::Error> { - self.inner.inner.poll_data().map_err(Into::into) + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.inner.poll_data(cx).map_err_(Into::into) } } @@ -514,9 +522,7 @@ impl Clone for ReleaseCapacity { impl PingPong { pub(crate) fn new(inner: proto::UserPings) -> Self { - PingPong { - inner, - } + PingPong { inner } } /// Send a `PING` frame to the peer. @@ -540,12 +546,10 @@ impl PingPong { // just drop it. drop(ping); - self.inner - .send_ping() - .map_err(|err| match err { - Some(err) => err.into(), - None => UserError::SendPingWhilePending.into() - }) + self.inner.send_ping().map_err(|err| match err { + Some(err) => err.into(), + None => UserError::SendPingWhilePending.into(), + }) } /// Polls for the acknowledgement of a previously [sent][] `PING` frame. @@ -553,8 +557,8 @@ impl PingPong { /// # Example /// /// ``` - /// # use futures::Future; - /// # fn doc(mut ping_pong: h2::PingPong) { + /// # #![feature(async_await)] + /// # async fn doc(mut ping_pong: h2::PingPong) { /// // let mut ping_pong = ... /// /// // First, send a PING. @@ -563,26 +567,23 @@ impl PingPong { /// .unwrap(); /// /// // And then wait for the PONG. - /// futures::future::poll_fn(move || { - /// ping_pong.poll_pong() - /// }).wait().unwrap(); + /// futures::future::poll_fn(move |cx| { + /// ping_pong.poll_pong(cx) + /// }).await.unwrap(); /// # } /// # fn main() {} /// ``` /// /// [sent]: struct.PingPong.html#method.send_ping - pub fn poll_pong(&mut self) -> Poll { - try_ready!(self.inner.poll_pong()); - Ok(Async::Ready(Pong { - _p: (), - })) + pub fn poll_pong(&mut self, cx: &mut Context) -> Poll> { + ready!(self.inner.poll_pong(cx))?; + Poll::Ready(Ok(Pong { _p: () })) } } impl fmt::Debug for PingPong { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("PingPong") - .finish() + fmt.debug_struct("PingPong").finish() } } @@ -595,16 +596,13 @@ impl Ping { /// /// [`PingPong`]: struct.PingPong.html pub fn opaque() -> Ping { - Ping { - _p: (), - } + Ping { _p: () } } } impl fmt::Debug for Ping { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("Ping") - .finish() + fmt.debug_struct("Ping").finish() } } @@ -612,7 +610,6 @@ impl fmt::Debug for Ping { impl fmt::Debug for Pong { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("Pong") - .finish() + fmt.debug_struct("Pong").finish() } } diff --git a/tests/h2-fuzz/Cargo.toml b/tests/h2-fuzz/Cargo.toml index be616d6b0..f9b9f119a 100644 --- a/tests/h2-fuzz/Cargo.toml +++ b/tests/h2-fuzz/Cargo.toml @@ -1,15 +1,15 @@ -[package] -name = "h2-fuzz" -version = "0.0.0" -publish = false -license = "MIT" -edition = "2018" - -[dependencies] -h2 = { path = "../.." } - -env_logger = { version = "0.5.3", default-features = false } -futures = "0.1.21" -honggfuzz = "0.5" -http = "0.1.3" -tokio-io = "0.1.4" +[package] +name = "h2-fuzz" +version = "0.0.0" +publish = false +license = "MIT" +edition = "2018" + +[dependencies] +h2 = { path = "../.." } + +env_logger = { version = "0.5.3", default-features = false } +futures-preview = "0.3.0-alpha.17" +honggfuzz = "0.5" +http = "0.1.3" +tokio = "0.2.0-alpha.1" diff --git a/tests/h2-fuzz/src/main.rs b/tests/h2-fuzz/src/main.rs index a36e4bfe0..6f4e80231 100644 --- a/tests/h2-fuzz/src/main.rs +++ b/tests/h2-fuzz/src/main.rs @@ -1,154 +1,133 @@ -use futures::prelude::*; -use futures::{executor, future, task}; -use http::{Method, Request}; -use std::cell::Cell; -use std::io::{self, Read, Write}; -use std::sync::Arc; -use tokio_io::{AsyncRead, AsyncWrite}; -use futures::stream::futures_unordered::FuturesUnordered; - -struct MockIo<'a> { - input: &'a [u8], -} - -impl<'a> MockIo<'a> { - fn next_byte(&mut self) -> Option { - if let Some(&c) = self.input.first() { - self.input = &self.input[1..]; - Some(c) - } else { - None - } - } - - fn next_u32(&mut self) -> u32 { - (self.next_byte().unwrap_or(0) as u32) << 8 | self.next_byte().unwrap_or(0) as u32 - } -} - -impl<'a> Read for MockIo<'a> { - fn read(&mut self, buf: &mut [u8]) -> Result { - let mut len = self.next_u32() as usize; - if self.input.is_empty() { - Ok(0) - } else if len == 0 { - task::current().notify(); - Err(io::ErrorKind::WouldBlock.into()) - } else { - if len > self.input.len() { - len = self.input.len(); - } - - if len > buf.len() { - len = buf.len(); - } - buf[0..len].copy_from_slice(&self.input[0..len]); - self.input = &self.input[len..]; - Ok(len) - } - } -} - -impl<'a> AsyncRead for MockIo<'a> { - unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [u8]) -> bool { - false - } -} - -impl<'a> Write for MockIo<'a> { - fn write(&mut self, buf: &[u8]) -> Result { - let len = std::cmp::min(self.next_u32() as usize, buf.len()); - if len == 0 { - if self.input.is_empty() { - Err(io::ErrorKind::BrokenPipe.into()) - } else { - task::current().notify(); - Err(io::ErrorKind::WouldBlock.into()) - } - } else { - Ok(len) - } - } - - fn flush(&mut self) -> Result<(), io::Error> { - Ok(()) - } -} - -impl<'a> AsyncWrite for MockIo<'a> { - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(Async::Ready(())) - } -} - -struct MockNotify { - notified: Cell, -} - -unsafe impl Sync for MockNotify {} - -impl executor::Notify for MockNotify { - fn notify(&self, _id: usize) { - self.notified.set(true); - } -} - -impl MockNotify { - fn take_notify(&self) -> bool { - self.notified.replace(false) - } -} - -fn run(script: &[u8]) -> Result<(), h2::Error> { - let notify = Arc::new(MockNotify { - notified: Cell::new(false), - }); - let notify_handle: executor::NotifyHandle = notify.clone().into(); - let io = MockIo { input: script }; - let (mut h2, mut connection) = h2::client::handshake(io).wait()?; - let mut futs = FuturesUnordered::new(); - let future = future::poll_fn(|| { - if let Async::Ready(()) = connection.poll()? { - return Ok(Async::Ready(())); - } - while futs.len() < 128 { - if h2.poll_ready()?.is_not_ready() { - break; - } - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); - let (resp, mut send) = h2.send_request(request, false)?; - send.send_data(vec![0u8; 32769].into(), true).unwrap(); - drop(send); - futs.push(resp); - } - loop { - match futs.poll() { - Ok(Async::NotReady) | Ok(Async::Ready(None)) => break, - r @ Ok(Async::Ready(_)) | r @ Err(_) => { - eprintln!("{:?}", r); - } - } - } - Ok::<_, h2::Error>(Async::NotReady) - }); - let mut spawn = executor::spawn(future); - loop { - if let Async::Ready(()) = spawn.poll_future_notify(¬ify_handle, 0)? { - return Ok(()); - } - assert!(notify.take_notify()); - } -} - -fn main() { - env_logger::init(); - loop { - honggfuzz::fuzz!(|data: &[u8]| { - eprintln!("{:?}", run(data)); - }); - } -} +#![feature(async_await)] +use futures::future; +use futures::stream::FuturesUnordered; +use futures::Stream; +use http::{Method, Request}; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; + +struct MockIo<'a> { + input: &'a [u8], +} + +impl<'a> MockIo<'a> { + fn next_byte(&mut self) -> Option { + if let Some(&c) = self.input.first() { + self.input = &self.input[1..]; + Some(c) + } else { + None + } + } + + fn next_u32(&mut self) -> u32 { + (self.next_byte().unwrap_or(0) as u32) << 8 | self.next_byte().unwrap_or(0) as u32 + } +} + +impl<'a> AsyncRead for MockIo<'a> { + unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [u8]) -> bool { + false + } + + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut len = self.next_u32() as usize; + if self.input.is_empty() { + Poll::Ready(Ok(0)) + } else if len == 0 { + cx.waker().clone().wake(); + Poll::Pending + } else { + if len > self.input.len() { + len = self.input.len(); + } + + if len > buf.len() { + len = buf.len(); + } + buf[0..len].copy_from_slice(&self.input[0..len]); + self.input = &self.input[len..]; + Poll::Ready(Ok(len)) + } + } +} + +impl<'a> AsyncWrite for MockIo<'a> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let len = std::cmp::min(self.next_u32() as usize, buf.len()); + if len == 0 { + if self.input.is_empty() { + Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())) + } else { + cx.waker().clone().wake(); + Poll::Pending + } + } else { + Poll::Ready(Ok(len)) + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +async fn run(script: &[u8]) -> Result<(), h2::Error> { + let io = MockIo { input: script }; + let (mut h2, mut connection) = h2::client::handshake(io).await?; + let mut futs = FuturesUnordered::new(); + let future = future::poll_fn(|cx| { + if let Poll::Ready(()) = Pin::new(&mut connection).poll(cx)? { + return Poll::Ready(Ok::<_, h2::Error>(())); + } + while futs.len() < 128 { + if !h2.poll_ready(cx)?.is_ready() { + break; + } + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + let (resp, mut send) = h2.send_request(request, false)?; + send.send_data(vec![0u8; 32769].into(), true).unwrap(); + drop(send); + futs.push(resp); + } + loop { + match Pin::new(&mut futs).poll_next(cx) { + Poll::Pending | Poll::Ready(None) => break, + r @ Poll::Ready(Some(Ok(_))) | r @ Poll::Ready(Some(Err(_))) => { + eprintln!("{:?}", r); + } + } + } + Poll::Pending + }); + future.await?; + Ok(()) +} + +fn main() { + env_logger::init(); + let rt = tokio::runtime::Runtime::new().unwrap(); + loop { + honggfuzz::fuzz!(|data: &[u8]| { + eprintln!("{:?}", rt.block_on(run(data))); + }); + } +} diff --git a/tests/h2-support/Cargo.toml b/tests/h2-support/Cargo.toml index b30c22718..04a15882f 100644 --- a/tests/h2-support/Cargo.toml +++ b/tests/h2-support/Cargo.toml @@ -9,8 +9,7 @@ h2 = { path = "../..", features = ["unstable"] } bytes = "0.4.7" env_logger = "0.5.9" -futures = "0.1.21" +futures-preview = "0.3.0-alpha.17" http = "0.1.5" string = "0.2" -tokio-io = "0.1.6" -tokio-timer = "0.1.2" +tokio = "0.2.0-alpha.1" diff --git a/tests/h2-support/src/assert.rs b/tests/h2-support/src/assert.rs index eeb89db1e..8bc6d25c7 100644 --- a/tests/h2-support/src/assert.rs +++ b/tests/h2-support/src/assert.rs @@ -1,9 +1,10 @@ - #[macro_export] macro_rules! assert_closed { ($transport:expr) => {{ - assert_eq!($transport.poll().unwrap(), None.into()); - }} + use futures::StreamExt; + + assert!($transport.next().await.is_none()); + }}; } #[macro_export] @@ -13,7 +14,7 @@ macro_rules! assert_headers { h2::frame::Frame::Headers(v) => v, f => panic!("expected HEADERS; actual={:?}", f), } - }} + }}; } #[macro_export] @@ -23,7 +24,7 @@ macro_rules! assert_data { h2::frame::Frame::Data(v) => v, f => panic!("expected DATA; actual={:?}", f), } - }} + }}; } #[macro_export] @@ -33,7 +34,7 @@ macro_rules! assert_ping { h2::frame::Frame::Ping(v) => v, f => panic!("expected PING; actual={:?}", f), } - }} + }}; } #[macro_export] @@ -43,28 +44,56 @@ macro_rules! assert_settings { h2::frame::Frame::Settings(v) => v, f => panic!("expected SETTINGS; actual={:?}", f), } - }} + }}; } #[macro_export] macro_rules! poll_err { ($transport:expr) => {{ - match $transport.poll() { - Err(e) => e, + use futures::StreamExt; + match $transport.next().await { + Some(Err(e)) => e, frame => panic!("expected error; actual={:?}", frame), } - }} + }}; } #[macro_export] macro_rules! poll_frame { ($type: ident, $transport:expr) => {{ + use futures::StreamExt; use h2::frame::Frame; - use futures::Async; - match $transport.poll() { - Ok(Async::Ready(Some(Frame::$type(frame)))) => frame, + match $transport.next().await { + Some(Ok(Frame::$type(frame))) => frame, frame => panic!("unexpected frame; actual={:?}", frame), } - }} + }}; +} + +#[macro_export] +macro_rules! assert_default_settings { + ($settings: expr) => {{ + assert_frame_eq($settings, frame::Settings::default()); + }}; +} + +use h2::frame::Frame; + +pub fn assert_frame_eq, U: Into>(t: T, u: U) { + let actual: Frame = t.into(); + let expected: Frame = u.into(); + match (actual, expected) { + (Frame::Data(a), Frame::Data(b)) => { + assert_eq!( + a.payload().len(), + b.payload().len(), + "assert_frame_eq data payload len" + ); + assert_eq!(a, b, "assert_frame_eq"); + } + (a, b) => { + assert_eq!(a, b, "assert_frame_eq"); + } + } } diff --git a/tests/h2-support/src/client_ext.rs b/tests/h2-support/src/client_ext.rs index 9a4d6f993..43203010b 100644 --- a/tests/h2-support/src/client_ext.rs +++ b/tests/h2-support/src/client_ext.rs @@ -11,8 +11,8 @@ pub trait SendRequestExt { impl SendRequestExt for SendRequest where - B: IntoBuf, - B::Buf: 'static, + B: IntoBuf + Unpin, + B::Buf: Unpin + 'static, { fn get(&mut self, uri: &str) -> ResponseFuture { let req = Request::builder() diff --git a/tests/h2-support/src/future_ext.rs b/tests/h2-support/src/future_ext.rs index f08710da4..9f659b344 100644 --- a/tests/h2-support/src/future_ext.rs +++ b/tests/h2-support/src/future_ext.rs @@ -1,220 +1,54 @@ -use futures::{Async, Future, Poll}; - -use std::fmt; +use futures::FutureExt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; /// Future extension helpers that are useful for tests -pub trait FutureExt: Future { - /// Panic on error - fn unwrap(self) -> Unwrap - where - Self: Sized, - Self::Error: fmt::Debug, - { - Unwrap { - inner: self, - } - } - - /// Panic on success, yielding the content of an `Err`. - fn unwrap_err(self) -> UnwrapErr - where - Self: Sized, - Self::Error: fmt::Debug, - { - UnwrapErr { - inner: self, - } - } - - /// Panic on success, with a message. - fn expect_err(self, msg: T) -> ExpectErr - where - Self: Sized, - Self::Error: fmt::Debug, - T: fmt::Display, - { - ExpectErr{ - inner: self, - msg: msg.to_string(), - } - } - - /// Panic on error, with a message. - fn expect(self, msg: T) -> Expect - where - Self: Sized, - Self::Error: fmt::Debug, - T: fmt::Display, - { - Expect { - inner: self, - msg: msg.to_string(), - } - } - +pub trait TestFuture: Future { /// Drive `other` by polling `self`. /// /// `self` must not resolve before `other` does. - fn drive(self, other: T) -> Drive + fn drive(&mut self, other: T) -> Drive<'_, Self, T> where T: Future, - T::Error: fmt::Debug, - Self: Future + Sized, - Self::Error: fmt::Debug, + Self: Future + Sized, { Drive { - driver: Some(self), - future: other, - } - } - - /// Wrap this future in one that will yield NotReady once before continuing. - /// - /// This allows the executor to poll other futures before trying this one - /// again. - fn yield_once(self) -> Box> - where - Self: Future + Sized + 'static, - { - Box::new(super::util::yield_once().then(move |_| self)) - } -} - -impl FutureExt for T {} - -// ===== Unwrap ====== - -/// Panic on error -pub struct Unwrap { - inner: T, -} - -impl Future for Unwrap -where - T: Future, - T::Item: fmt::Debug, - T::Error: fmt::Debug, -{ - type Item = T::Item; - type Error = (); - - fn poll(&mut self) -> Poll { - Ok(self.inner.poll().unwrap()) - } -} - -// ===== UnwrapErr ====== - -/// Panic on success. -pub struct UnwrapErr { - inner: T, -} - -impl Future for UnwrapErr -where - T: Future, - T::Item: fmt::Debug, - T::Error: fmt::Debug, -{ - type Item = T::Error; - type Error = (); - - fn poll(&mut self) -> Poll { - match self.inner.poll() { - Ok(Async::Ready(v)) => panic!("Future::unwrap_err() on an Ok value: {:?}", v), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => Ok(Async::Ready(e)), + driver: self, + future: Box::pin(other), } } } - - -// ===== Expect ====== - -/// Panic on error -pub struct Expect { - inner: T, - msg: String, -} - -impl Future for Expect -where - T: Future, - T::Item: fmt::Debug, - T::Error: fmt::Debug, -{ - type Item = T::Item; - type Error = (); - - fn poll(&mut self) -> Poll { - Ok(self.inner.poll().expect(&self.msg)) - } -} - -// ===== ExpectErr ====== - -/// Panic on success -pub struct ExpectErr { - inner: T, - msg: String, -} - -impl Future for ExpectErr -where - T: Future, - T::Item: fmt::Debug, - T::Error: fmt::Debug, -{ - type Item = T::Error; - type Error = (); - - fn poll(&mut self) -> Poll { - match self.inner.poll() { - Ok(Async::Ready(v)) => panic!("{}: {:?}", self.msg, v), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => Ok(Async::Ready(e)), - } - } -} +impl TestFuture for T {} // ===== Drive ====== /// Drive a future to completion while also polling the driver /// /// This is useful for H2 futures that also require the connection to be polled. -pub struct Drive { - driver: Option, - future: U, +pub struct Drive<'a, T, U> { + driver: &'a mut T, + future: Pin>, } -impl Future for Drive +impl<'a, T, U> Future for Drive<'a, T, U> where - T: Future, + T: Future + Unpin, U: Future, - T::Error: fmt::Debug, - U::Error: fmt::Debug, { - type Item = (T, U::Item); - type Error = (); + type Output = U::Output; - fn poll(&mut self) -> Poll { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut looped = false; - loop { - match self.future.poll() { - Ok(Async::Ready(val)) => { - // Get the driver - let driver = self.driver.take().unwrap(); - - return Ok((driver, val).into()); - }, - Ok(_) => {}, - Err(e) => panic!("unexpected error; {:?}", e), + match self.future.poll_unpin(cx) { + Poll::Ready(val) => return Poll::Ready(val), + Poll::Pending => {} } - match self.driver.as_mut().unwrap().poll() { - Ok(Async::Ready(_)) => { + match self.driver.poll_unpin(cx) { + Poll::Ready(_) => { if looped { // Try polling the future one last time panic!("driver resolved before future") @@ -222,12 +56,11 @@ where looped = true; continue; } - }, - Ok(Async::NotReady) => {}, - Err(e) => panic!("unexpected error; {:?}", e), + } + Poll::Pending => {} } - return Ok(Async::NotReady); + return Poll::Pending; } } } diff --git a/tests/h2-support/src/lib.rs b/tests/h2-support/src/lib.rs index 6b572fb09..98025ae51 100644 --- a/tests/h2-support/src/lib.rs +++ b/tests/h2-support/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(async_await)] //! Utilities to support tests. #[macro_use] @@ -6,17 +7,16 @@ pub mod assert; pub mod raw; pub mod frames; -pub mod prelude; pub mod mock; pub mod mock_io; -pub mod notify; +pub mod prelude; pub mod util; mod client_ext; mod future_ext; -pub use crate::client_ext::{SendRequestExt}; -pub use crate::future_ext::{FutureExt, Unwrap}; +pub use crate::client_ext::SendRequestExt; +pub use crate::future_ext::TestFuture; pub type WindowSize = usize; pub const DEFAULT_WINDOW_SIZE: WindowSize = (1 << 16) - 1; diff --git a/tests/h2-support/src/mock.rs b/tests/h2-support/src/mock.rs index ea5723c52..19ae80446 100644 --- a/tests/h2-support/src/mock.rs +++ b/tests/h2-support/src/mock.rs @@ -1,18 +1,21 @@ -use crate::{frames, FutureExt, SendFrame}; +use crate::SendFrame; -use h2::{self, RecvError, SendError}; use h2::frame::{self, Frame}; +use h2::{self, RecvError, SendError}; -use futures::{Async, Future, Poll, Stream}; -use futures::sync::oneshot; -use futures::task::{self, Task}; +use futures::future::poll_fn; +use futures::{ready, Stream, StreamExt}; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_io::io::read_exact; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::timer::Delay; -use std::{cmp, fmt, io, usize}; -use std::io::ErrorKind::WouldBlock; +use super::assert::assert_frame_eq; +use futures::executor::block_on; +use std::pin::Pin; use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll, Waker}; +use std::time::{Duration, Instant}; +use std::{cmp, io, usize}; /// A mock I/O #[derive(Debug)] @@ -36,19 +39,19 @@ struct Inner { rx: Vec, /// Notify when data is ready to be received. - rx_task: Option, + rx_task: Option, /// Data written by the `h2` library to be read by the test case. tx: Vec, /// Notify when data is written. This notifies the test case waiters. - tx_task: Option, + tx_task: Option, - /// Number of bytes that can be written before `write` returns `NotReady`. + /// Number of bytes that can be written before `write` returns `Poll::Pending`. tx_rem: usize, /// Task to notify when write capacity becomes available. - tx_rem_task: Option, + tx_rem_task: Option, /// True when the pipe is closed. closed: bool, @@ -80,9 +83,7 @@ pub fn new_with_write_capacity(cap: usize) -> (Mock, Handle) { }; let handle = Handle { - codec: h2::Codec::new(Pipe { - inner, - }), + codec: h2::Codec::new(Pipe { inner }), }; (mock, handle) @@ -97,207 +98,258 @@ impl Handle { } /// Send a frame - pub fn send(&mut self, item: SendFrame) -> Result<(), SendError> { + pub async fn send(&mut self, item: SendFrame) -> Result<(), SendError> { // Queue the frame self.codec.buffer(item).unwrap(); // Flush the frame - assert!(self.codec.flush()?.is_ready()); - + poll_fn(|cx| { + let p = self.codec.flush(cx); + assert!(p.is_ready()); + p + }) + .await?; Ok(()) } /// Writes the client preface - pub fn write_preface(&mut self) { - use std::io::Write; - - // Write the connnection preface - self.codec.get_mut().write(PREFACE).unwrap(); + pub async fn write_preface(&mut self) { + self.codec.get_mut().write_all(PREFACE).await.unwrap(); } /// Read the client preface - pub fn read_preface(self) -> Box> { - let buf = vec![0; PREFACE.len()]; - let ret = read_exact(self, buf).and_then(|(me, buf)| { - assert_eq!(buf, PREFACE); - Ok(me) - }); + pub async fn read_preface(&mut self) -> io::Result<()> { + let mut buf = vec![0u8; PREFACE.len()]; + self.read_exact(&mut buf).await?; + assert_eq!(buf, PREFACE); + Ok(()) + } - Box::new(ret) + pub async fn recv_frame>(&mut self, expected: F) { + let frame = self.next().await.unwrap().unwrap(); + assert_frame_eq(frame, expected); + } + + pub async fn send_frame>(&mut self, frame: F) { + self.send(frame.into()).await.unwrap(); + } + + pub async fn recv_eof(&mut self) { + let frame = self.next().await; + assert!(frame.is_none()); + } + + pub async fn send_bytes(&mut self, data: &[u8]) { + use bytes::Buf; + use std::io::Cursor; + + let buf: Vec<_> = data.into(); + let mut buf = Cursor::new(buf); + + poll_fn(move |cx| { + while buf.has_remaining() { + let res = Pin::new(self.codec.get_mut()) + .poll_write_buf(cx, &mut buf) + .map_err(|e| panic!("write err={:?}", e)); + + ready!(res).unwrap(); + } + + Poll::Ready(()) + }) + .await; } /// Perform the H2 handshake - pub fn assert_client_handshake( - self, - ) -> Box> { + pub async fn assert_client_handshake(&mut self) -> frame::Settings { self.assert_client_handshake_with_settings(frame::Settings::default()) + .await } /// Perform the H2 handshake - pub fn assert_client_handshake_with_settings( - mut self, - settings: T, - ) -> Box> + pub async fn assert_client_handshake_with_settings(&mut self, settings: T) -> frame::Settings where T: Into, { let settings = settings.into(); // Send a settings frame - self.send(settings.into()).unwrap(); - - let ret = self.read_preface() - .unwrap() - .and_then(|me| me.into_future().unwrap()) - .map(|(frame, mut me)| { - match frame { - Some(Frame::Settings(settings)) => { - // Send the ACK - let ack = frame::Settings::ack(); - - // TODO: Don't unwrap? - me.send(ack.into()).unwrap(); - - (settings, me) - }, - Some(frame) => { - panic!("unexpected frame; frame={:?}", frame); - }, - None => { - panic!("unexpected EOF"); - }, - } - }) - .then(|res| { - let (settings, me) = res.unwrap(); + self.send(settings.into()).await.unwrap(); + self.read_preface().await.unwrap(); + + let settings = match self.next().await { + Some(frame) => match frame.unwrap() { + Frame::Settings(settings) => { + // Send the ACK + let ack = frame::Settings::ack(); - me.into_future() - .map_err(|_| unreachable!("all previous futures unwrapped")) - .map(|(frame, me)| { - let f = assert_settings!(frame.unwrap()); + // TODO: Don't unwrap? + self.send(ack.into()).await.unwrap(); - // Is ACK - assert!(f.is_ack()); + settings + } + frame => { + panic!("unexpected frame; frame={:?}", frame); + } + }, + None => { + panic!("unexpected EOF"); + } + }; - (settings, me) - }) - }); + let frame = self.next().await.unwrap().unwrap(); + let f = assert_settings!(frame); - Box::new(ret) - } + // Is ACK + assert!(f.is_ack()); + settings + } /// Perform the H2 handshake - pub fn assert_server_handshake( - self, - ) -> Box> { + pub async fn assert_server_handshake(&mut self) -> frame::Settings { self.assert_server_handshake_with_settings(frame::Settings::default()) + .await } /// Perform the H2 handshake - pub fn assert_server_handshake_with_settings( - mut self, - settings: T, - ) -> Box> + pub async fn assert_server_handshake_with_settings(&mut self, settings: T) -> frame::Settings where T: Into, { - self.write_preface(); + self.write_preface().await; let settings = settings.into(); - self.send(settings.into()).unwrap(); - - let ret = self.into_future() - .unwrap() - .map(|(frame, mut me)| { - match frame { - Some(Frame::Settings(settings)) => { - // Send the ACK - let ack = frame::Settings::ack(); - - // TODO: Don't unwrap? - me.send(ack.into()).unwrap(); - - (settings, me) - }, - Some(frame) => { - panic!("unexpected frame; frame={:?}", frame); - }, - None => { - panic!("unexpected EOF"); - }, + self.send(settings.into()).await.unwrap(); + + let frame = self.next().await; + let settings = match frame { + Some(frame) => match frame.unwrap() { + Frame::Settings(settings) => { + // Send the ACK + let ack = frame::Settings::ack(); + + // TODO: Don't unwrap? + self.send(ack.into()).await.unwrap(); + + settings } - }) - .then(|res| { - let (settings, me) = res.unwrap(); + frame => panic!("unexpected frame; frame={:?}", frame), + }, + None => panic!("unexpected EOF"), + }; + let frame = self.next().await; + let f = assert_settings!(frame.unwrap().unwrap()); - me.into_future() - .map_err(|e| panic!("error: {:?}", e)) - .map(|(frame, me)| { - let f = assert_settings!(frame.unwrap()); + // Is ACK + assert!(f.is_ack()); - // Is ACK - assert!(f.is_ack()); + settings + } + + pub async fn ping_pong(&mut self, payload: [u8; 8]) { + self.send_frame(crate::frames::ping(payload)).await; + self.recv_frame(crate::frames::ping(payload).pong()).await; + } + + pub async fn buffer_bytes(&mut self, num: usize) { + // Set tx_rem to num + { + let mut i = self.codec.get_mut().inner.lock().unwrap(); + i.tx_rem = num; + } + + poll_fn(move |cx| { + { + let mut inner = self.codec.get_mut().inner.lock().unwrap(); + if inner.tx_rem == 0 { + inner.tx_rem = usize::MAX; + } else { + inner.tx_task = Some(cx.waker().clone()); + return Poll::Pending; + } + } - (settings, me) - }) - }); + Poll::Ready(()) + }) + .await; + } + + pub async fn unbounded_bytes(&mut self) { + let mut i = self.codec.get_mut().inner.lock().unwrap(); + i.tx_rem = usize::MAX; - Box::new(ret) + if let Some(task) = i.tx_rem_task.take() { + task.wake(); + } } } impl Stream for Handle { - type Item = Frame; - type Error = RecvError; + type Item = Result; - fn poll(&mut self) -> Poll, RecvError> { - self.codec.poll() + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.codec).poll_next(cx) } } -impl io::Read for Handle { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.codec.get_mut().read(buf) +impl AsyncRead for Handle { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(self.codec.get_mut()).poll_read(cx, buf) } } -impl AsyncRead for Handle {} - -impl io::Write for Handle { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.codec.get_mut().write(buf) +impl AsyncWrite for Handle { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(self.codec.get_mut()).poll_write(cx, buf) } - fn flush(&mut self) -> io::Result<()> { - Ok(()) + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(self.codec.get_mut()).poll_flush(cx) } -} -impl AsyncWrite for Handle { - fn shutdown(&mut self) -> Poll<(), io::Error> { - use std::io::Write; - tokio_io::try_nb!(self.flush()); - Ok(().into()) + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(self.codec.get_mut()).poll_shutdown(cx) } } impl Drop for Handle { fn drop(&mut self) { - assert!(self.codec.shutdown().unwrap().is_ready()); + block_on(async { + poll_fn(|cx| { + assert!(self.codec.shutdown(cx).is_ready()); - let mut me = self.codec.get_mut().inner.lock().unwrap(); - me.closed = true; + let mut me = self.codec.get_mut().inner.lock().unwrap(); + me.closed = true; - if let Some(task) = me.rx_task.take() { - task.notify(); - } + if let Some(task) = me.rx_task.take() { + task.wake(); + } + Poll::Ready(()) + }) + .await; + }); } } // ===== impl Mock ===== -impl io::Read for Mock { - fn read(&mut self, buf: &mut [u8]) -> io::Result { +impl AsyncRead for Mock { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { assert!( buf.len() > 0, "attempted read with zero length buffer... wut?" @@ -307,34 +359,36 @@ impl io::Read for Mock { if me.rx.is_empty() { if me.closed { - return Ok(0); + return Poll::Ready(Ok(0)); } - me.rx_task = Some(task::current()); - return Err(WouldBlock.into()); + me.rx_task = Some(cx.waker().clone()); + return Poll::Pending; } let n = cmp::min(buf.len(), me.rx.len()); buf[..n].copy_from_slice(&me.rx[..n]); me.rx.drain(..n); - Ok(n) + Poll::Ready(Ok(n)) } } -impl AsyncRead for Mock {} - -impl io::Write for Mock { - fn write(&mut self, mut buf: &[u8]) -> io::Result { +impl AsyncWrite for Mock { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: &[u8], + ) -> Poll> { let mut me = self.pipe.inner.lock().unwrap(); if me.closed { - return Ok(buf.len()); + return Poll::Ready(Ok(buf.len())); } if me.tx_rem == 0 { - me.tx_rem_task = Some(task::current()); - return Err(io::ErrorKind::WouldBlock.into()); + me.tx_rem_task = Some(cx.waker().clone()); + return Poll::Pending; } if buf.len() > me.tx_rem { @@ -345,22 +399,18 @@ impl io::Write for Mock { me.tx_rem -= buf.len(); if let Some(task) = me.tx_task.take() { - task.notify(); + task.wake(); } - Ok(buf.len()) + Poll::Ready(Ok(buf.len())) } - fn flush(&mut self) -> io::Result<()> { - Ok(()) + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } -} -impl AsyncWrite for Mock { - fn shutdown(&mut self) -> Poll<(), io::Error> { - use std::io::Write; - tokio_io::try_nb!(self.flush()); - Ok(().into()) + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } } @@ -370,15 +420,19 @@ impl Drop for Mock { me.closed = true; if let Some(task) = me.tx_task.take() { - task.notify(); + task.wake(); } } } // ===== impl Pipe ===== -impl io::Read for Pipe { - fn read(&mut self, buf: &mut [u8]) -> io::Result { +impl AsyncRead for Pipe { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { assert!( buf.len() > 0, "attempted read with zero length buffer... wut?" @@ -386,382 +440,48 @@ impl io::Read for Pipe { let mut me = self.inner.lock().unwrap(); - if me.tx.is_empty() { if me.closed { - return Ok(0); + return Poll::Ready(Ok(0)); } - me.tx_task = Some(task::current()); - return Err(WouldBlock.into()); + me.tx_task = Some(cx.waker().clone()); + return Poll::Pending; } let n = cmp::min(buf.len(), me.tx.len()); buf[..n].copy_from_slice(&me.tx[..n]); me.tx.drain(..n); - Ok(n) + Poll::Ready(Ok(n)) } } -impl AsyncRead for Pipe {} - -impl io::Write for Pipe { - fn write(&mut self, buf: &[u8]) -> io::Result { +impl AsyncWrite for Pipe { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { let mut me = self.inner.lock().unwrap(); me.rx.extend(buf); if let Some(task) = me.rx_task.take() { - task.notify(); + task.wake(); } - Ok(buf.len()) + Poll::Ready(Ok(buf.len())) } - fn flush(&mut self) -> io::Result<()> { - Ok(()) + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } -} -impl AsyncWrite for Pipe { - fn shutdown(&mut self) -> Poll<(), io::Error> { - use std::io::Write; - tokio_io::try_nb!(self.flush()); - Ok(().into()) + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) } } -pub trait HandleFutureExt { - fn recv_settings(self) - -> RecvFrame, Handle), Error = ()>>> - where - Self: Sized + 'static, - Self: Future, - Self::Error: fmt::Debug, - { - self.recv_custom_settings(frame::Settings::default()) - } - - fn recv_custom_settings(self, settings: T) - -> RecvFrame, Handle), Error = ()>>> - where - Self: Sized + 'static, - Self: Future, - Self::Error: fmt::Debug, - T: Into, - { - let map = self - .map(|(settings, handle)| (Some(settings.into()), handle)) - .unwrap(); - - let boxed: Box, Handle), Error = ()>> = - Box::new(map); - RecvFrame { - inner: boxed, - frame: Some(settings.into().into()), - } - } - - fn ignore_settings(self) -> Box> - where - Self: Sized + 'static, - Self: Future, - Self::Error: fmt::Debug, - { - Box::new(self.map(|(_settings, handle)| handle).unwrap()) - } - - fn recv_frame(self, frame: T) -> RecvFrame<::Future> - where - Self: IntoRecvFrame + Sized, - T: Into, - { - self.into_recv_frame(Some(frame.into())) - } - - fn recv_eof(self) -> RecvFrame<::Future> - where - Self: IntoRecvFrame + Sized, - { - self.into_recv_frame(None) - } - - fn send_frame(self, frame: T) -> SendFrameFut - where - Self: Sized, - T: Into, - { - SendFrameFut { - inner: self, - frame: Some(frame.into()), - } - } - - fn send_bytes(self, data: &[u8]) -> Box> - where - Self: Future + Sized + 'static, - Self::Error: fmt::Debug, - { - use bytes::Buf; - use futures::future::poll_fn; - use std::io::Cursor; - - let buf: Vec<_> = data.into(); - let mut buf = Cursor::new(buf); - - Box::new(self.and_then(move |handle| { - let mut handle = Some(handle); - - poll_fn(move || { - while buf.has_remaining() { - let res = handle.as_mut().unwrap() - .codec.get_mut() - .write_buf(&mut buf) - .map_err(|e| panic!("write err={:?}", e)); - - futures::try_ready!(res); - } - - Ok(handle.take().unwrap().into()) - }) - })) - } - - fn ping_pong(self, payload: [u8; 8]) -> RecvFrame< as IntoRecvFrame>::Future> - where - Self: Future + Sized + 'static, - Self::Error: fmt::Debug, - { - self.send_frame(frames::ping(payload)) - .recv_frame(frames::ping(payload).pong()) - } - - fn idle_ms(self, ms: usize) -> Box> - where - Self: Sized + 'static, - Self: Future, - Self::Error: fmt::Debug, - { - use std::thread; - use std::time::Duration; - - - Box::new(self.and_then(move |handle| { - // This is terrible... but oh well - let (tx, rx) = oneshot::channel(); - - thread::spawn(move || { - thread::sleep(Duration::from_millis(ms as u64)); - tx.send(()).unwrap(); - }); - - Idle { - handle: Some(handle), - timeout: rx, - }.map_err(|_| unreachable!()) - })) - } - - fn buffer_bytes(self, num: usize) -> Box> - where Self: Sized + 'static, - Self: Future, - Self::Error: fmt::Debug, - { - use futures::future::poll_fn; - - Box::new(self.and_then(move |mut handle| { - // Set tx_rem to num - { - let mut i = handle.codec.get_mut().inner.lock().unwrap(); - i.tx_rem = num; - } - - let mut handle = Some(handle); - - poll_fn(move || { - { - let mut inner = handle.as_mut().unwrap() - .codec.get_mut().inner.lock().unwrap(); - - if inner.tx_rem == 0 { - inner.tx_rem = usize::MAX; - } else { - inner.tx_task = Some(task::current()); - return Ok(Async::NotReady); - } - } - - Ok(handle.take().unwrap().into()) - }) - })) - } - - fn unbounded_bytes(self) -> Box> - where Self: Sized + 'static, - Self: Future, - Self::Error: fmt::Debug, - { - Box::new(self.and_then(|mut handle| { - { - let mut i = handle.codec.get_mut().inner.lock().unwrap(); - i.tx_rem = usize::MAX; - - if let Some(task) = i.tx_rem_task.take() { - task.notify(); - } - } - - Ok(handle.into()) - })) - } - - fn then_notify(self, tx: oneshot::Sender<()>) -> Box> - where Self: Sized + 'static, - Self: Future, - Self::Error: fmt::Debug, - { - Box::new(self.map(move |handle| { - tx.send(()).unwrap(); - handle - })) - } - - fn wait_for(self, other: F) -> Box> - where - F: Future + 'static, - Self: Future + Sized + 'static - { - Box::new(self.then(move |result| { - other.then(move |_| result) - })) - } - - fn close(self) -> Box> - where - Self: Future + Sized + 'static, - { - Box::new(self.map(drop)) - } -} - -pub struct RecvFrame { - inner: T, - frame: Option, -} - -impl Future for RecvFrame -where - T: Future, Handle)>, - T::Error: fmt::Debug, -{ - type Item = Handle; - type Error = (); - - fn poll(&mut self) -> Poll { - use self::Frame::Data; - - let (frame, handle) = match self.inner.poll().unwrap() { - Async::Ready((frame, handle)) => (frame, handle), - Async::NotReady => return Ok(Async::NotReady), - }; - - match (frame, &self.frame) { - (Some(Data(ref a)), &Some(Data(ref b))) => { - assert_eq!(a.payload().len(), b.payload().len(), "recv_frame data payload len"); - assert_eq!(a, b, "recv_frame"); - } - (ref a, b) => { - assert_eq!(a, b, "recv_frame"); - } - } - - Ok(Async::Ready(handle)) - } -} - -pub struct SendFrameFut { - inner: T, - frame: Option, -} - -impl Future for SendFrameFut -where - T: Future, - T::Error: fmt::Debug, -{ - type Item = Handle; - type Error = (); - - fn poll(&mut self) -> Poll { - let mut handle = match self.inner.poll().unwrap() { - Async::Ready(handle) => handle, - Async::NotReady => return Ok(Async::NotReady), - }; - handle.send(self.frame.take().unwrap()).unwrap(); - Ok(Async::Ready(handle)) - } -} - -pub struct Idle { - handle: Option, - timeout: oneshot::Receiver<()>, -} - -impl Future for Idle { - type Item = Handle; - type Error = (); - - fn poll(&mut self) -> Poll { - if self.timeout.poll().unwrap().is_ready() { - return Ok(self.handle.take().unwrap().into()); - } - - match self.handle.as_mut().unwrap().poll() { - Ok(Async::NotReady) => Ok(Async::NotReady), - res => { - panic!("Idle received unexpected frame on handle; frame={:?}", res); - }, - } - } -} - -impl HandleFutureExt for T -where - T: Future + 'static, -{ -} - -pub trait IntoRecvFrame { - type Future: Future; - fn into_recv_frame(self, frame: Option) -> RecvFrame; -} - -impl IntoRecvFrame for Handle { - type Future = ::futures::stream::StreamFuture; - - fn into_recv_frame(self, frame: Option) -> RecvFrame { - RecvFrame { - inner: self.into_future(), - frame: frame, - } - } -} - -impl IntoRecvFrame for T -where - T: Future + 'static, - T::Error: fmt::Debug, -{ - type Future = Box, Handle), Error = ()>>; - - fn into_recv_frame(self, frame: Option) -> RecvFrame { - let into_fut = Box::new( - self.unwrap() - .and_then(|handle| handle.into_future().unwrap()), - ); - RecvFrame { - inner: into_fut, - frame: frame, - } - } +pub async fn idle_ms(ms: u64) { + Delay::new(Instant::now() + Duration::from_millis(ms)).await } diff --git a/tests/h2-support/src/mock_io.rs b/tests/h2-support/src/mock_io.rs index b5565543c..8b3c95bb2 100644 --- a/tests/h2-support/src/mock_io.rs +++ b/tests/h2-support/src/mock_io.rs @@ -74,9 +74,9 @@ #![allow(deprecated)] -use std::{cmp, io}; use std::collections::VecDeque; use std::time::{Duration, Instant}; +use std::{cmp, io}; /// An I/O handle that follows a predefined script. /// @@ -85,13 +85,12 @@ use std::time::{Duration, Instant}; #[derive(Debug)] pub struct Mock { inner: Inner, - tokio: tokio::Inner, - r#async: Option, + tokio: tokio_::Inner, } #[derive(Debug)] pub struct Handle { - inner: tokio::Handle, + inner: tokio_::Handle, } /// Builds `Mock` instances. @@ -99,9 +98,6 @@ pub struct Handle { pub struct Builder { // Sequence of actions for the Mock to take actions: VecDeque, - - // true for Tokio, false for blocking, None to auto detect - r#async: Option, } #[derive(Debug, Clone)] @@ -159,7 +155,7 @@ impl Builder { /// Build a `Mock` value paired with a handle pub fn build_with_handle(&mut self) -> (Mock, Handle) { - let (tokio, handle) = tokio::Inner::new(); + let (tokio, handle) = tokio_::Inner::new(); let src = self.clone(); @@ -169,7 +165,6 @@ impl Builder { waiting: None, }, tokio: tokio, - r#async: src.r#async, }; let handle = Handle { inner: handle }; @@ -198,45 +193,10 @@ impl Handle { } } -impl Mock { - fn sync_read(&mut self, dst: &mut [u8]) -> io::Result { - use std::thread; - - loop { - match self.inner.read(dst) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - if let Some(rem) = self.inner.remaining_wait() { - thread::sleep(rem); - } else { - // We've entered a dead lock scenario. The peer expects - // a write but we are reading. - panic!("mock_io::Mock expects write but currently blocked in read"); - } - } - ret => return ret, - } - } - } - - fn sync_write(&mut self, src: &[u8]) -> io::Result { - match self.inner.write(src) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - panic!("mock_io::Mock not currently expecting a write"); - } - ret => ret, - } - } - - /// Returns `true` if running in a futures-rs task context - fn is_async(&self) -> bool { - self.r#async.unwrap_or(tokio::is_task_ctx()) - } -} - impl Inner { fn read(&mut self, dst: &mut [u8]) -> io::Result { match self.action() { - Some(&mut Action::Read(ref mut data)) =>{ + Some(&mut Action::Read(ref mut data)) => { // Figure out how much to copy let n = cmp::min(dst.len(), data.len()); @@ -253,9 +213,7 @@ impl Inner { // Either waiting or expecting a write Err(io::ErrorKind::WouldBlock.into()) } - None => { - Ok(0) - } + None => Ok(0), } } @@ -347,55 +305,25 @@ impl Inner { } } -impl io::Read for Mock { - fn read(&mut self, dst: &mut [u8]) -> io::Result { - if self.is_async() { - tokio::async_read(self, dst) - } else { - self.sync_read(dst) - } - } -} - -impl io::Write for Mock { - fn write(&mut self, src: &[u8]) -> io::Result { - if self.is_async() { - tokio::async_write(self, src) - } else { - self.sync_write(src) - } - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - // use tokio::*; -mod tokio { +mod tokio_ { use super::*; - use futures::{Future, Stream, Poll, Async}; - use futures::sync::mpsc; - use futures::task::{self, Task}; - use tokio_io::{AsyncRead, AsyncWrite}; - use tokio_timer::{Timer, Sleep}; + use futures::channel::mpsc; + use futures::{ready, FutureExt, Stream}; + use std::task::{Context, Poll, Waker}; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::timer::Delay; - use std::io; + use std::pin::Pin; - impl Builder { - pub fn set_async(&mut self, is_async: bool) -> &mut Self { - self.r#async = Some(is_async); - self - } - } + use std::io; #[derive(Debug)] pub struct Inner { - timer: Timer, - sleep: Option, - read_wait: Option, + sleep: Option, + read_wait: Option, rx: mpsc::UnboundedReceiver, } @@ -408,11 +336,11 @@ mod tokio { impl Handle { pub fn read(&mut self, buf: &[u8]) { - mpsc::UnboundedSender::send(&mut self.tx, Action::Read(buf.into())).unwrap(); + self.tx.unbounded_send(Action::Read(buf.into())).unwrap(); } pub fn write(&mut self, buf: &[u8]) { - mpsc::UnboundedSender::send(&mut self.tx, Action::Write(buf.into())).unwrap(); + self.tx.unbounded_send(Action::Write(buf.into())).unwrap(); } } @@ -420,16 +348,9 @@ mod tokio { impl Inner { pub fn new() -> (Inner, Handle) { - // TODO: We probably want a higher resolution timer. - let timer = tokio_timer::wheel() - .tick_duration(Duration::from_millis(1)) - .max_timeout(Duration::from_secs(3600)) - .build(); - let (tx, rx) = mpsc::unbounded(); let inner = Inner { - timer: timer, sleep: None, read_wait: None, rx: rx, @@ -440,8 +361,8 @@ mod tokio { (inner, handle) } - pub(super) fn poll_action(&mut self) -> Poll, ()> { - self.rx.poll() + pub(super) fn poll_action(&mut self, cx: &mut Context) -> Poll> { + Pin::new(&mut self.rx).poll_next(cx) } } @@ -450,7 +371,7 @@ mod tokio { match self.inner.action() { Some(&mut Action::Read(_)) | None => { if let Some(task) = self.tokio.read_wait.take() { - task.notify(); + task.wake(); } } _ => {} @@ -458,106 +379,113 @@ mod tokio { } } - pub fn async_read(me: &mut Mock, dst: &mut [u8]) -> io::Result { - loop { - if let Some(ref mut sleep) = me.tokio.sleep { - let res = r#try!(sleep.poll()); - - if !res.is_ready() { - return Err(io::ErrorKind::WouldBlock.into()); + impl AsyncRead for Mock { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + if let Some(sleep) = &mut self.tokio.sleep { + ready!(sleep.poll_unpin(cx)); } - } - // If a sleep is set, it has already fired - me.tokio.sleep = None; + // If a sleep is set, it has already fired + self.tokio.sleep = None; - match me.inner.read(dst) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - if let Some(rem) = me.inner.remaining_wait() { - me.tokio.sleep = Some(me.tokio.timer.sleep(rem)); - } else { - me.tokio.read_wait = Some(task::current()); - return Err(io::ErrorKind::WouldBlock.into()); - } - } - Ok(0) => { - // TODO: Extract - match me.tokio.poll_action().unwrap() { - Async::Ready(Some(action)) => { - me.inner.actions.push_back(action); - continue; - } - Async::Ready(None) => { - return Ok(0); + match self.inner.read(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + if let Some(rem) = self.inner.remaining_wait() { + self.tokio.sleep = Some(Delay::new(Instant::now() + rem)); + } else { + self.tokio.read_wait = Some(cx.waker().clone()); + return Poll::Pending; } - Async::NotReady => { - return Err(io::ErrorKind::WouldBlock.into()); + } + Ok(0) => { + // TODO: Extract + match self.tokio.poll_action(cx) { + Poll::Ready(Some(action)) => { + self.inner.actions.push_back(action); + continue; + } + Poll::Ready(None) => { + return Poll::Ready(Ok(0)); + } + Poll::Pending => { + return Poll::Pending; + } } } + ret => return Poll::Ready(ret), } - ret => return ret, } } } - pub fn async_write(me: &mut Mock, src: &[u8]) -> io::Result { - loop { - if let Some(ref mut sleep) = me.tokio.sleep { - let res = r#try!(sleep.poll()); - - if !res.is_ready() { - return Err(io::ErrorKind::WouldBlock.into()); + impl AsyncWrite for Mock { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + if let Some(sleep) = &mut self.tokio.sleep { + ready!(sleep.poll_unpin(cx)); } - } - - // If a sleep is set, it has already fired - me.tokio.sleep = None; - match me.inner.write(src) { - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - if let Some(rem) = me.inner.remaining_wait() { - me.tokio.sleep = Some(me.tokio.timer.sleep(rem)); - } else { - panic!("unexpected WouldBlock"); - } - } - Ok(0) => { - // TODO: Is this correct? - if !me.inner.actions.is_empty() { - return Err(io::ErrorKind::WouldBlock.into()); - } + // If a sleep is set, it has already fired + self.tokio.sleep = None; - // TODO: Extract - match me.tokio.poll_action().unwrap() { - Async::Ready(Some(action)) => { - me.inner.actions.push_back(action); - continue; + match self.inner.write(buf) { + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + if let Some(rem) = self.inner.remaining_wait() { + self.tokio.sleep = Some(Delay::new(Instant::now() + rem)); + } else { + panic!("unexpected WouldBlock"); } - Async::Ready(None) => { - panic!("unexpected write"); + } + Ok(0) => { + // TODO: Is this correct? + if !self.inner.actions.is_empty() { + return Poll::Pending; } - Async::NotReady => { - return Err(io::ErrorKind::WouldBlock.into()); + + // TODO: Extract + match self.tokio.poll_action(cx) { + Poll::Ready(Some(action)) => { + self.inner.actions.push_back(action); + continue; + } + Poll::Ready(None) => { + panic!("unexpected write"); + } + Poll::Pending => return Poll::Pending, } } - } - ret => { - me.maybe_wakeup_reader(); - return ret; + ret => { + self.maybe_wakeup_reader(); + return Poll::Ready(ret); + } } } } - } - impl AsyncRead for Mock { - } + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } - impl AsyncWrite for Mock { - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(Async::Ready(())) + fn poll_shutdown( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) } } + /* + TODO: Is this required? + /// Returns `true` if called from the context of a futures-rs Task pub fn is_task_ctx() -> bool { use std::panic; @@ -577,4 +505,5 @@ mod tokio { // Return the result r } + */ } diff --git a/tests/h2-support/src/notify.rs b/tests/h2-support/src/notify.rs deleted file mode 100644 index 6f3e96a2d..000000000 --- a/tests/h2-support/src/notify.rs +++ /dev/null @@ -1,55 +0,0 @@ -use futures::executor::{self, Notify}; - -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::Ordering::SeqCst; - -pub struct MockNotify { - inner: Arc, -} - -struct Inner { - notified: AtomicBool, -} - -impl MockNotify { - pub fn new() -> Self { - MockNotify { - inner: Arc::new(Inner { - notified: AtomicBool::new(false), - }), - } - } - - pub fn with R, R>(&self, f: F) -> R { - use futures::Async::Ready; - use futures::future::poll_fn; - - self.clear(); - - let mut f = Some(f); - - let res = executor::spawn(poll_fn(move || { - Ok::<_, ()>(Ready(f.take().unwrap()())) - })).poll_future_notify(&self.inner, 0); - - match res { - Ok(Ready(v)) => v, - _ => unreachable!(), - } - } - - pub fn clear(&self) { - self.inner.notified.store(false, SeqCst); - } - - pub fn is_notified(&self) -> bool { - self.inner.notified.load(SeqCst) - } -} - -impl Notify for Inner { - fn notify(&self, _: usize) { - self.notified.store(true, SeqCst); - } -} diff --git a/tests/h2-support/src/prelude.rs b/tests/h2-support/src/prelude.rs index 16f3bd525..866ef15ed 100644 --- a/tests/h2-support/src/prelude.rs +++ b/tests/h2-support/src/prelude.rs @@ -1,21 +1,17 @@ - // Re-export H2 crate pub use h2; -pub use h2::*; pub use h2::client; pub use h2::frame::StreamId; pub use h2::server; +pub use h2::*; // Re-export mock -pub use super::mock::{self, HandleFutureExt}; +pub use super::mock::{self, idle_ms}; // Re-export frames helpers pub use super::frames; -// Re-export mock notify -pub use super::notify::MockNotify; - // Re-export utility mod pub use super::util; @@ -23,28 +19,32 @@ pub use super::util; pub use super::{Codec, SendFrame}; // Re-export macros -pub use super::{assert_ping, assert_data, assert_headers, assert_closed, - raw_codec, poll_frame, poll_err}; +pub use super::{ + assert_closed, assert_data, assert_default_settings, assert_headers, assert_ping, poll_err, + poll_frame, raw_codec, +}; + +pub use super::assert::assert_frame_eq; // Re-export useful crates -pub use {bytes, env_logger, futures, http, tokio_io}; pub use super::mock_io; +pub use {bytes, env_logger, futures, http, tokio::io as tokio_io}; // Re-export primary future types -pub use futures::{Future, IntoFuture, Sink, Stream}; +pub use futures::{Future, Sink, Stream}; // And our Future extensions -pub use super::future_ext::{FutureExt, Unwrap}; +pub use super::future_ext::TestFuture; // Our client_ext helpers -pub use super::client_ext::{SendRequestExt}; +pub use super::client_ext::SendRequestExt; // Re-export HTTP types pub use http::{uri, HeaderMap, Method, Request, Response, StatusCode, Version}; pub use bytes::{Buf, BufMut, Bytes, BytesMut, IntoBuf}; -pub use tokio_io::{AsyncRead, AsyncWrite}; +pub use tokio::io::{AsyncRead, AsyncWrite}; pub use std::thread; pub use std::time::Duration; @@ -52,7 +52,10 @@ pub use std::time::Duration; // ===== Everything under here shouldn't be used ===== // TODO: work on deleting this code +use futures::future; pub use futures::future::poll_fn; +use futures::future::Either::*; +use std::pin::Pin; pub trait MockH2 { fn handshake(&mut self) -> &mut Self; @@ -69,29 +72,33 @@ impl MockH2 for super::mock_io::Builder { } pub trait ClientExt { - fn run(&mut self, f: F) -> Result; + fn run<'a, F: Future + Unpin + 'a>( + &'a mut self, + f: F, + ) -> Pin + 'a>>; } impl ClientExt for client::Connection where - T: AsyncRead + AsyncWrite + 'static, - B: IntoBuf + 'static, + T: AsyncRead + AsyncWrite + Unpin + 'static, + B: IntoBuf + Unpin + 'static, + B::Buf: Unpin, { - fn run(&mut self, f: F) -> Result { - use futures::future; - use futures::future::Either::*; - - let res = future::poll_fn(|| self.poll()).select2(f).wait(); - - match res { - Ok(A((_, b))) => { - // Connection is done... - b.wait() - }, - Ok(B((v, _))) => return Ok(v), - Err(A((e, _))) => panic!("err: {:?}", e), - Err(B((e, _))) => return Err(e), - } + fn run<'a, F: Future + Unpin + 'a>( + &'a mut self, + f: F, + ) -> Pin + 'a>> { + let res = future::select(self, f); + Box::pin(async { + match res.await { + Left((Ok(_), b)) => { + // Connection is done... + b.await + } + Right((v, _)) => return v, + Left((Err(e), _)) => panic!("err: {:?}", e), + } + }) } } diff --git a/tests/h2-support/src/util.rs b/tests/h2-support/src/util.rs index 4c1a705ad..b854e6e1d 100644 --- a/tests/h2-support/src/util.rs +++ b/tests/h2-support/src/util.rs @@ -1,24 +1,28 @@ use h2; -use string::{String, TryFrom}; use bytes::Bytes; -use futures::{Async, Future, Poll}; +use futures::ready; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use string::{String, TryFrom}; pub fn byte_str(s: &str) -> String { String::try_from(Bytes::from(s)).unwrap() } -pub fn yield_once() -> impl Future { +pub async fn yield_once() { let mut yielded = false; - futures::future::poll_fn(move || { + futures::future::poll_fn(move |cx| { if yielded { - Ok(Async::Ready(())) + Poll::Ready(()) } else { yielded = true; - futures::task::current().notify(); - Ok(Async::NotReady) + cx.waker().clone().wake(); + Poll::Pending } }) + .await; } pub fn wait_for_capacity(stream: h2::SendStream, target: usize) -> WaitForCapacity { @@ -40,18 +44,17 @@ impl WaitForCapacity { } impl Future for WaitForCapacity { - type Item = h2::SendStream; - type Error = (); + type Output = h2::SendStream; - fn poll(&mut self) -> Poll { - let _ = futures::try_ready!(self.stream().poll_capacity().map_err(|_| panic!())); + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let _ = ready!(self.stream().poll_capacity(cx)).unwrap(); let act = self.stream().capacity(); if act >= self.target { - return Ok(self.stream.take().unwrap().into()); + return Poll::Ready(self.stream.take().unwrap().into()); } - Ok(Async::NotReady) + Poll::Pending } } diff --git a/tests/h2-tests/Cargo.toml b/tests/h2-tests/Cargo.toml index 059bae131..5876b5ec3 100644 --- a/tests/h2-tests/Cargo.toml +++ b/tests/h2-tests/Cargo.toml @@ -10,4 +10,5 @@ edition = "2018" [dev-dependencies] h2-support = { path = "../h2-support" } log = "0.4.1" -tokio = "0.1.8" +futures-preview = "0.3.0-alpha.17" +tokio = "0.2.0-alpha.1" diff --git a/tests/h2-tests/tests/client_request.rs b/tests/h2-tests/tests/client_request.rs index 7a534a577..1574114cb 100644 --- a/tests/h2-tests/tests/client_request.rs +++ b/tests/h2-tests/tests/client_request.rs @@ -1,7 +1,14 @@ +#![feature(async_await)] + +use futures::future::{join, ready, select, Either}; +use futures::stream::FuturesUnordered; +use futures::StreamExt; use h2_support::prelude::*; +use std::pin::Pin; +use std::task::Context; -#[test] -fn handshake() { +#[tokio::test] +async fn handshake() { let _ = env_logger::try_init(); let mock = mock_io::Builder::new() @@ -9,61 +16,60 @@ fn handshake() { .write(SETTINGS_ACK) .build(); - let (_client, h2) = client::handshake(mock).wait().unwrap(); + let (_client, h2) = client::handshake(mock).await.unwrap(); log::trace!("hands have been shook"); // At this point, the connection should be closed - h2.wait().unwrap(); + h2.await.unwrap(); } -#[test] -fn client_other_thread() { +#[tokio::test] +async fn client_other_thread() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .send_frame(frames::headers(1).response(200).eos()) - .close(); + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; - let h2 = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, h2)| { - ::std::thread::spawn(move || { - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); - let res = client - .send_request(request, true) - .unwrap().0 - .wait() - .expect("request"); - assert_eq!(res.status(), StatusCode::OK); - }); - - h2.expect("h2") + let h2 = async move { + let (mut client, h2) = client::handshake(io).await.unwrap(); + tokio::spawn(async move { + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + let _res = client + .send_request(request, true) + .unwrap() + .0 + .await + .expect("request"); }); - h2.join(srv).wait().expect("wait"); + h2.await.expect("h2"); + }; + join(srv, h2).await; } -#[test] -fn recv_invalid_server_stream_id() { +#[tokio::test] +async fn recv_invalid_server_stream_id() { let _ = env_logger::try_init(); let mock = mock_io::Builder::new() .handshake() // Write GET / .write(&[ - 0, 0, 0x10, 1, 5, 0, 0, 0, 1, 0x82, 0x87, 0x41, 0x8B, 0x9D, 0x29, - 0xAC, 0x4B, 0x8F, 0xA8, 0xE9, 0x19, 0x97, 0x21, 0xE9, 0x84, + 0, 0, 0x10, 1, 5, 0, 0, 0, 1, 0x82, 0x87, 0x41, 0x8B, 0x9D, 0x29, 0xAC, 0x4B, 0x8F, + 0xA8, 0xE9, 0x19, 0x97, 0x21, 0xE9, 0x84, ]) .write(SETTINGS_ACK) // Read response @@ -72,7 +78,7 @@ fn recv_invalid_server_stream_id() { .write(&[0, 0, 8, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]) .build(); - let (mut client, h2) = client::handshake(mock).wait().unwrap(); + let (mut client, h2) = client::handshake(mock).await.unwrap(); // Send the request let request = Request::builder() @@ -84,431 +90,429 @@ fn recv_invalid_server_stream_id() { let (response, _) = client.send_request(request, true).unwrap(); // The connection errors - assert!(h2.wait().is_err()); + assert!(h2.await.is_err()); // The stream errors - assert!(response.wait().is_err()); + assert!(response.await.is_err()); } -#[test] -fn request_stream_id_overflows() { +#[tokio::test] +async fn request_stream_id_overflows() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - - let h2 = client::Builder::new() - .initial_stream_id(::std::u32::MAX >> 1) - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|(mut client, h2)| { - let request = Request::builder() - .method(Method::GET) - .uri("https://example.com/") - .body(()) - .unwrap(); - - // first request is allowed - let (response, _) = client.send_request(request, true).unwrap(); - - h2.drive(response).and_then(move |(h2, _)| { - let request = Request::builder() - .method(Method::GET) - .uri("https://example.com/") - .body(()) - .unwrap(); + let (io, mut srv) = mock::new(); + + let h2 = async move { + let (mut client, mut h2) = client::Builder::new() + .initial_stream_id(::std::u32::MAX >> 1) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); - // second cannot use the next stream id, it's over + // first request is allowed + let (response, _) = client.send_request(request, true).unwrap(); + let _x = h2.drive(response).await.unwrap(); - let poll_err = client.poll_ready().unwrap_err(); - assert_eq!(poll_err.to_string(), "user error: stream ID overflowed"); + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); + // second cannot use the next stream id, it's over + let poll_err = poll_fn(|cx| client.poll_ready(cx)).await.unwrap_err(); + assert_eq!(poll_err.to_string(), "user error: stream ID overflowed"); - let err = client.send_request(request, true).unwrap_err(); - assert_eq!(err.to_string(), "user error: stream ID overflowed"); + let err = client.send_request(request, true).unwrap_err(); + assert_eq!(err.to_string(), "user error: stream ID overflowed"); - h2.expect("h2").map(|ret| { - // Hold on to the `client` handle to avoid sending a GO_AWAY - // frame. - drop(client); - ret - }) - }) - }); + h2.await.unwrap(); + }; - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(::std::u32::MAX >> 1) .request("GET", "https://example.com/") .eos(), ) - .send_frame( - frames::headers(::std::u32::MAX >> 1) - .response(200) - .eos() - ) - .idle_ms(10) - .close(); + .await; + srv.send_frame(frames::headers(::std::u32::MAX >> 1).response(200).eos()) + .await; + idle_ms(10).await; + }; - h2.join(srv).wait().expect("wait"); + join(srv, h2).await; } -#[test] -fn client_builder_max_concurrent_streams() { +#[tokio::test] +async fn client_builder_max_concurrent_streams() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); let mut settings = frame::Settings::default(); settings.set_max_concurrent_streams(Some(1)); - let srv = srv - .assert_client_handshake() - .unwrap() - .recv_custom_settings(settings) - .recv_frame( + let srv = async move { + let rcvd_settings = srv.assert_client_handshake().await; + assert_frame_eq(settings, rcvd_settings); + + srv.recv_frame( frames::headers(1) .request("GET", "https://example.com/") - .eos() + .eos(), ) - .send_frame(frames::headers(1).response(200).eos()) - .close(); + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; let mut builder = client::Builder::new(); builder.max_concurrent_streams(1); - let h2 = builder - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|(mut client, h2)| { - let request = Request::builder() - .method(Method::GET) - .uri("https://example.com/") - .body(()) - .unwrap(); - - let (response, _) = client.send_request(request, true).unwrap(); - h2.drive(response).map(move |(h2, _)| (client, h2)) - }); + let h2 = async move { + let (mut client, mut h2) = builder.handshake::<_, Bytes>(io).await.unwrap(); + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + }; - h2.join(srv).wait().expect("wait"); + join(srv, h2).await; } -#[test] -fn request_over_max_concurrent_streams_errors() { +#[tokio::test] +async fn request_over_max_concurrent_streams_errors() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - - let srv = srv.assert_client_handshake_with_settings(frames::settings() - // super tiny server - .max_concurrent_streams(1)) - .unwrap() - .recv_settings() - .recv_frame( + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings( + frames::settings() + // super tiny server + .max_concurrent_streams(1), + ) + .await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("POST", "https://example.com/") .eos(), ) - .send_frame(frames::headers(1).response(200).eos()) - .recv_frame(frames::headers(3).request("POST", "https://example.com/")) - .send_frame(frames::headers(3).response(200)) - .recv_frame(frames::data(3, "hello").eos()) - .send_frame(frames::data(3, "").eos()) - .recv_frame(frames::headers(5).request("POST", "https://example.com/")) - .send_frame(frames::headers(5).response(200)) - .recv_frame(frames::data(5, "hello").eos()) - .send_frame(frames::data(5, "").eos()) - .close(); - - let h2 = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, h2)| { - // we send a simple req here just to drive the connection so we can - // receive the server settings. - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.recv_frame(frames::headers(3).request("POST", "https://example.com/")) + .await; + srv.send_frame(frames::headers(3).response(200)).await; + srv.recv_frame(frames::data(3, "hello").eos()).await; + srv.send_frame(frames::data(3, "").eos()).await; + srv.recv_frame(frames::headers(5).request("POST", "https://example.com/")) + .await; + srv.send_frame(frames::headers(5).response(200)).await; + srv.recv_frame(frames::data(5, "hello").eos()).await; + srv.send_frame(frames::data(5, "").eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.expect("handshake"); + // we send a simple req here just to drive the connection so we can + // receive the server settings. + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + // first request is allowed + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); - // first request is allowed - let (response, _) = client.send_request(request, true).unwrap(); - h2.drive(response).map(move |(h2, _)| (client, h2)) - }) - .and_then(|(mut client, h2)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); + // first request is allowed + let (resp1, mut stream1) = client.send_request(request, false).unwrap(); - // first request is allowed - let (resp1, mut stream1) = client.send_request(request, false).unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); + // second request is put into pending_open + let (resp2, mut stream2) = client.send_request(request, false).unwrap(); - // second request is put into pending_open - let (resp2, mut stream2) = client.send_request(request, false).unwrap(); + let request = Request::builder() + .method(Method::GET) + .uri("https://example.com/") + .body(()) + .unwrap(); - let request = Request::builder() - .method(Method::GET) - .uri("https://example.com/") - .body(()) - .unwrap(); + let waker = futures::task::noop_waker(); + let mut cx = Context::from_waker(&waker); - // third stream is over max concurrent - assert!(client.poll_ready().expect("poll_ready").is_not_ready()); + // third stream is over max concurrent + assert!(!client.poll_ready(&mut cx).is_ready()); - let err = client.send_request(request, true).unwrap_err(); - assert_eq!(err.to_string(), "user error: rejected"); + let err = client.send_request(request, true).unwrap_err(); + assert_eq!(err.to_string(), "user error: rejected"); - stream1.send_data("hello".into(), true).expect("req send_data"); + stream1 + .send_data("hello".into(), true) + .expect("req send_data"); - h2.drive(resp1.expect("req")).and_then(move |(h2, _)| { - stream2.send_data("hello".into(), true) - .expect("req2 send_data"); - h2.expect("h2").join(resp2.expect("req2")) - }) - }); + h2.drive(async move { + resp1.await.expect("req"); + stream2 + .send_data("hello".into(), true) + .expect("req2 send_data"); + }) + .await; + join(async move { h2.await.unwrap() }, async move { + resp2.await.unwrap() + }) + .await; + }; - h2.join(srv).wait().expect("wait"); + join(srv, h2).await; } -#[test] -fn send_request_poll_ready_when_connection_error() { +#[tokio::test] +async fn send_request_poll_ready_when_connection_error() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - - let srv = srv.assert_client_handshake_with_settings(frames::settings() - // super tiny server - .max_concurrent_streams(1)) - .unwrap() - .recv_settings() - .recv_frame( + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings( + frames::settings() + // super tiny server + .max_concurrent_streams(1), + ) + .await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("POST", "https://example.com/") .eos(), ) - .send_frame(frames::headers(1).response(200).eos()) - .recv_frame(frames::headers(3).request("POST", "https://example.com/").eos()) - .send_frame(frames::headers(8).response(200).eos()) - //.recv_frame(frames::headers(5).request("POST", "https://example.com/").eos()) - .close(); - - let h2 = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, h2)| { - // we send a simple req here just to drive the connection so we can - // receive the server settings. - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); - - // first request is allowed - let (response, _) = client.send_request(request, true).unwrap(); - h2.drive(response).map(move |(h2, _)| (client, h2)) - }) - .and_then(|(mut client, h2)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.recv_frame( + frames::headers(3) + .request("POST", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(8).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.expect("handshake"); + // we send a simple req here just to drive the connection so we can + // receive the server settings. + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); - // first request is allowed - let (resp1, _) = client.send_request(request, true).unwrap(); + // first request is allowed + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); - // second request is put into pending_open - let (resp2, _) = client.send_request(request, true).unwrap(); - - // third stream is over max concurrent - let until_ready = futures::future::poll_fn(move || { - client.poll_ready() - }).expect_err("client poll_ready").then(|_| Ok(())); - - // a FuturesUnordered is used on purpose! - // - // We don't want a join, since any of the other futures notifying - // will make the until_ready future polled again, but we are - // specifically testing that until_ready gets notified on its own. - let mut unordered = futures::stream::FuturesUnordered::>>::new(); - unordered.push(Box::new(until_ready)); - unordered.push(Box::new(h2.expect_err("client conn").then(|_| Ok(())))); - unordered.push(Box::new(resp1.expect_err("req1").then(|_| Ok(())))); - unordered.push(Box::new(resp2.expect_err("req2").then(|_| Ok(())))); - - unordered.for_each(|_| Ok(())) - }); + // first request is allowed + let (resp1, _) = client.send_request(request, true).unwrap(); - h2.join(srv).wait().expect("wait"); + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + + // second request is put into pending_open + let (resp2, _) = client.send_request(request, true).unwrap(); + + // third stream is over max concurrent + let until_ready = async move { + poll_fn(move |cx| client.poll_ready(cx)) + .await + .expect_err("client poll_ready"); + }; + + // a FuturesUnordered is used on purpose! + // + // We don't want a join, since any of the other futures notifying + // will make the until_ready future polled again, but we are + // specifically testing that until_ready gets notified on its own. + let mut unordered = + futures::stream::FuturesUnordered::>>>::new(); + unordered.push(Box::pin(until_ready)); + unordered.push(Box::pin(async move { + h2.await.expect_err("client conn"); + })); + unordered.push(Box::pin(async move { + resp1.await.expect_err("req1"); + })); + unordered.push(Box::pin(async move { + resp2.await.expect_err("req2"); + })); + + while let Some(_) = unordered.next().await {} + }; + + join(srv, h2).await; } -#[test] -fn send_reset_notifies_recv_stream() { +#[tokio::test] +async fn send_reset_notifies_recv_stream() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://example.com/") - ) - .send_frame(frames::headers(1).response(200)) - .recv_frame(frames::reset(1).refused()) - .recv_frame(frames::go_away(0)) - .recv_eof(); - - let client = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); - - // first request is allowed - let (resp1, tx) = client.send_request(request, false).unwrap(); - - conn.drive(resp1) - .map(move |(conn, res)| (client, conn, tx, res)) - }) - .and_then(|(client, conn, mut tx, res)| { - let tx = futures::future::poll_fn(move || { - tx.send_reset(h2::Reason::REFUSED_STREAM); - Ok(().into()) - }); - - - let rx = res - .into_body() - .for_each(|_| -> Result<(), _> { - unreachable!("no response body expected") - }); - // a FuturesUnordered is used on purpose! - // - // We don't want a join, since any of the other futures notifying - // will make the rx future polled again, but we are - // specifically testing that rx gets notified on its own. - let mut unordered = futures::stream::FuturesUnordered::>>::new(); - unordered.push(Box::new(rx.expect_err("RecvBody").then(|_| Ok(())))); - unordered.push(Box::new(tx)); - - conn.drive(unordered.for_each(|_| Ok(()))) - .and_then(move |(conn, _)| { - drop(client); // now let client gracefully goaway - conn.expect("client") - }) - }); - - client.join(srv).wait().expect("wait"); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.recv_frame(frames::reset(1).refused()).await; + srv.recv_frame(frames::go_away(0)).await; + srv.recv_eof().await; + }; + + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("handshake"); + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + + // first request is allowed + let (resp1, mut tx) = client.send_request(request, false).unwrap(); + let res = conn.drive(resp1).await.unwrap(); + + let tx = async move { + tx.send_reset(h2::Reason::REFUSED_STREAM); + }; + let rx = async { + let mut body = res.into_body(); + body.next().await.unwrap().expect_err("RecvBody"); + }; + + // a FuturesUnordered is used on purpose! + // + // We don't want a join, since any of the other futures notifying + // will make the rx future polled again, but we are + // specifically testing that rx gets notified on its own. + let mut unordered = FuturesUnordered::>>>::new(); + unordered.push(Box::pin(rx)); + unordered.push(Box::pin(tx)); + + conn.drive(unordered.for_each(ready)).await; + drop(client); // now let client gracefully goaway + conn.await.expect("client"); + }; + + join(srv, client).await; } -#[test] -fn http_11_request_without_scheme_or_authority() { +#[tokio::test] +async fn http_11_request_without_scheme_or_authority() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("GET", "/") - .scheme("http") - .eos(), - ) - .send_frame(frames::headers(1).response(200).eos()) - .close(); - - let h2 = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, h2)| { - // HTTP_11 request with just :path is allowed - let request = Request::builder() - .method(Method::GET) - .uri("/") - .body(()) - .unwrap(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("GET", "/").scheme("http").eos()) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.expect("handshake"); + + // HTTP_11 request with just :path is allowed + let request = Request::builder() + .method(Method::GET) + .uri("/") + .body(()) + .unwrap(); - let (response, _) = client.send_request(request, true).unwrap(); - h2.drive(response) - .map(move |(h2, _)| (client, h2)) - }); + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + }; - h2.join(srv).wait().expect("wait"); + join(srv, h2).await; } -#[test] -fn http_2_request_without_scheme_or_authority() { +#[tokio::test] +async fn http_2_request_without_scheme_or_authority() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .close(); - - let h2 = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, h2)| { - // HTTP_2 with only a :path is illegal, so this request should - // be rejected as a user error. - let request = Request::builder() - .version(Version::HTTP_2) - .method(Method::GET) - .uri("/") - .body(()) - .unwrap(); - - client - .send_request(request, true) - .expect_err("should be UserError"); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + }; + + let h2 = async move { + let (mut client, h2) = client::handshake(io).await.expect("handshake"); + + // HTTP_2 with only a :path is illegal, so this request should + // be rejected as a user error. + let request = Request::builder() + .version(Version::HTTP_2) + .method(Method::GET) + .uri("/") + .body(()) + .unwrap(); - h2.expect("h2").map(|ret| { - // Hold on to the `client` handle to avoid sending a GO_AWAY frame. - drop(client); - ret - }) - }); + client + .send_request(request, true) + .expect_err("should be UserError"); + let ret = h2.await.expect("h2"); + drop(client); + ret + }; - h2.join(srv).wait().expect("wait"); + join(srv, h2).await; } #[test] #[ignore] fn request_with_h1_version() {} -#[test] -fn request_with_connection_headers() { +#[tokio::test] +async fn request_with_connection_headers() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); // can't assert full handshake, since client never sends a request, and // thus never bothers to ack the settings... - let srv = srv.read_preface() - .unwrap() - .recv_frame(frames::settings()) + let srv = async move { + srv.read_preface().await.unwrap(); + srv.recv_frame(frames::settings()).await; // goaway is required to make sure the connection closes because // of no active streams - .recv_frame(frames::go_away(0)) - .close(); + srv.recv_frame(frames::go_away(0)).await; + }; let headers = vec![ ("connection", "foo"), @@ -519,355 +523,341 @@ fn request_with_connection_headers() { ("te", "boom"), ]; - let client = client::handshake(io) - .expect("handshake") - .and_then(move |(mut client, conn)| { - for (name, val) in headers { - let req = Request::builder() - .uri("https://http2.akamai.com/") - .header(name, val) - .body(()) - .unwrap(); - let err = client.send_request(req, true).expect_err(name); - - assert_eq!(err.to_string(), "user error: malformed headers"); - } - conn.unwrap() - }); + let client = async move { + let (mut client, conn) = client::handshake(io).await.expect("handshake"); - client.join(srv).wait().expect("wait"); + for (name, val) in headers { + let req = Request::builder() + .uri("https://http2.akamai.com/") + .header(name, val) + .body(()) + .unwrap(); + let err = client.send_request(req, true).expect_err(name); + assert_eq!(err.to_string(), "user error: malformed headers"); + } + drop(client); + conn.await.unwrap(); + }; + + join(srv, client).await; } -#[test] -fn connection_close_notifies_response_future() { +#[tokio::test] +async fn connection_close_notifies_response_future() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let (io, mut srv) = mock::new(); + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) + .await; // don't send any response, just close - .close(); + }; - let client = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); + let client = async move { + let (mut client, conn) = client::handshake(io).await.expect("handshake"); + + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - let req = client + let req = async move { + let res = client .send_request(request, true) .expect("send_request1") .0 - .then(|res| { - let err = res.expect_err("response"); - assert_eq!( - err.to_string(), - "broken pipe" - ); - Ok(()) - }); - - conn.expect("conn").join(req) - }); - - client.join(srv).wait().expect("wait"); + .await; + let err = res.expect_err("response"); + assert_eq!(err.to_string(), "broken pipe"); + }; + join(async move { conn.await.expect("conn") }, req).await; + }; + + join(srv, client).await; } -#[test] -fn connection_close_notifies_client_poll_ready() { +#[tokio::test] +async fn connection_close_notifies_client_poll_ready() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .close(); + .await; + }; - let client = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("handshake"); + + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - let req = client + let req = async { + let res = client .send_request(request, true) .expect("send_request1") .0 - .then(|res| { - let err = res.expect_err("response"); - assert_eq!( - err.to_string(), - "broken pipe" - ); - Ok::<_, ()>(()) - }); - - conn.drive(req) - .and_then(move |(_conn, _)| { - let err = client.poll_ready().expect_err("poll_ready"); - assert_eq!( - err.to_string(), - "broken pipe" - ); - Ok(()) - }) - }); + .await; + let err = res.expect_err("response"); + assert_eq!(err.to_string(), "broken pipe"); + }; + + conn.drive(req).await; + + let err = poll_fn(move |cx| client.poll_ready(cx)) + .await + .expect_err("poll_ready"); + assert_eq!(err.to_string(), "broken pipe"); + }; - client.join(srv).wait().expect("wait"); + join(srv, client).await; } -#[test] -fn sending_request_on_closed_connection() { +#[tokio::test] +async fn sending_request_on_closed_connection() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .send_frame(frames::headers(1).response(200).eos()) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; // a bad frame! - .send_frame(frames::headers(0).response(200).eos()) - .close(); + srv.send_frame(frames::headers(0).response(200).eos()).await; + }; - let h2 = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, h2)| { - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); + let h2 = async move { + let (mut client, h2) = client::handshake(io).await.expect("handshake"); + + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - // first request works - let req = client + // first request works + let req = Box::pin(async { + client .send_request(request, true) .expect("send_request1") .0 - .expect("response1") - .map(|_| ()); - - // after finish request1, there should be a conn error - let h2 = h2.then(|res| { - res.expect_err("h2 error"); - Ok::<(), ()>(()) - }); - - h2.select(req) - .then(|res| match res { - Ok((_, next)) => next, - Err(_) => unreachable!("both selected futures cannot error"), - }) - .map(move |_| client) - }) - .and_then(|mut client| { - let poll_err = client.poll_ready().unwrap_err(); - let msg = "protocol error: unspecific protocol error detected"; - assert_eq!(poll_err.to_string(), msg); - - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); - let send_err = client.send_request(request, true).unwrap_err(); - assert_eq!(send_err.to_string(), msg); + .await + .expect("response1"); + }); - Ok(()) + // after finish request1, there should be a conn error + let h2 = Box::pin(async move { + h2.await.expect_err("h2 error"); }); - h2.join(srv).wait().expect("wait"); + match select(h2, req).await { + Either::Left((_, req)) => req.await, + Either::Right((_, _h2)) => unreachable!("Shouldn't happen"), // TODO: Is this correct? + }; + + let poll_err = poll_fn(|cx| client.poll_ready(cx)).await.unwrap_err(); + let msg = "protocol error: unspecific protocol error detected"; + assert_eq!(poll_err.to_string(), msg); + + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + let send_err = client.send_request(request, true).unwrap_err(); + assert_eq!(send_err.to_string(), msg); + }; + + join(srv, h2).await; } -#[test] -fn recv_too_big_headers() { +#[tokio::test] +async fn recv_too_big_headers() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_custom_settings( - frames::settings() - .max_header_list_size(10) - ) - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_frame_eq(settings, frames::settings().max_header_list_size(10)); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .recv_frame( + .await; + srv.recv_frame( frames::headers(3) .request("GET", "https://http2.akamai.com/") .eos(), ) - .send_frame(frames::headers(1).response(200).eos()) - .send_frame(frames::headers(3).response(200)) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.send_frame(frames::headers(3).response(200)).await; // no reset for 1, since it's closed anyways // but reset for 3, since server hasn't closed stream - .recv_frame(frames::reset(3).refused()) - .idle_ms(10) - .close(); - - let client = client::Builder::new() - .max_header_list_size(10) - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); - - let req1 = client - .send_request(request, true) - .expect("send_request") - .0 - .expect_err("response1") - .map(|err| { - assert_eq!( - err.reason(), - Some(Reason::REFUSED_STREAM) - ); - }); - - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); - - let req2 = client - .send_request(request, true) - .expect("send_request") - .0 - .expect_err("response2") - .map(|err| { - assert_eq!( - err.reason(), - Some(Reason::REFUSED_STREAM) - ); - }); - - conn.drive(req1.join(req2)) - .and_then(|(conn, _)| conn.expect("client")) - .map(|c| (c, client)) - }); + srv.recv_frame(frames::reset(3).refused()).await; + idle_ms(10).await; + }; + + let client = async move { + let (mut client, mut conn) = client::Builder::new() + .max_header_list_size(10) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - client.join(srv).wait().expect("wait"); + let req1 = client.send_request(request, true); + let req1 = async move { + let err = req1.expect("send_request").0.await.expect_err("response1"); + assert_eq!(err.reason(), Some(Reason::REFUSED_STREAM)); + }; + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + + let req2 = client.send_request(request, true); + let req2 = async move { + let err = req2.expect("send_request").0.await.expect_err("response2"); + assert_eq!(err.reason(), Some(Reason::REFUSED_STREAM)); + }; + + conn.drive(join(req1, req2)).await; + conn.await.expect("client"); + }; + join(srv, client).await; } -#[test] -fn pending_send_request_gets_reset_by_peer_properly() { +#[tokio::test] +async fn pending_send_request_gets_reset_by_peer_properly() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let payload = [0; (frame::DEFAULT_INITIAL_WINDOW_SIZE * 2) as usize]; + let payload = vec![0; (frame::DEFAULT_INITIAL_WINDOW_SIZE * 2) as usize]; let max_frame_size = frame::DEFAULT_MAX_FRAME_SIZE as usize; - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("GET", "https://http2.akamai.com/"), - ) + let srv = async { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("GET", "https://http2.akamai.com/")) + .await; // Note that we can only send up to ~4 frames of data by default - .recv_frame(frames::data(1, &payload[0..max_frame_size])) - .recv_frame(frames::data(1, &payload[max_frame_size..(max_frame_size*2)])) - .recv_frame(frames::data(1, &payload[(max_frame_size*2)..(max_frame_size*3)])) - .recv_frame(frames::data(1, &payload[(max_frame_size*3)..(max_frame_size*4-1)])) + srv.recv_frame(frames::data(1, &payload[0..max_frame_size])) + .await; + srv.recv_frame(frames::data( + 1, + &payload[max_frame_size..(max_frame_size * 2)], + )) + .await; + srv.recv_frame(frames::data( + 1, + &payload[(max_frame_size * 2)..(max_frame_size * 3)], + )) + .await; + srv.recv_frame(frames::data( + 1, + &payload[(max_frame_size * 3)..(max_frame_size * 4 - 1)], + )) + .await; - .idle_ms(100) + idle_ms(100).await; - .send_frame(frames::reset(1).refused()) + srv.send_frame(frames::reset(1).refused()).await; // Because all active requests are finished, connection should shutdown // and send a GO_AWAY frame. If the reset stream is bugged (and doesn't // count towards concurrency limit), then connection will not send // a GO_AWAY and this test will fail. - .recv_frame(frames::go_away(0)) + srv.recv_frame(frames::go_away(0)).await; + drop(srv); + }; + + let client = async { + let (mut client, mut conn) = client::Builder::new() + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - .close(); + let (response, mut stream) = client.send_request(request, false).expect("send_request"); - let client = client::Builder::new() - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); + let response = async move { + let err = response.await.expect_err("response"); + assert_eq!(err.reason(), Some(Reason::REFUSED_STREAM)); + }; - let (response, mut stream) = client - .send_request(request, false) - .expect("send_request"); + // Send the data + stream.send_data(payload[..].into(), true).unwrap(); + conn.drive(response).await; + drop(client); + drop(stream); + conn.await.expect("client"); + }; - let response = response.expect_err("response") - .map(|err| { - assert_eq!( - err.reason(), - Some(Reason::REFUSED_STREAM) - ); - }); + join(srv, client).await; +} - // Send the data - stream.send_data(payload[..].into(), true).unwrap(); +#[tokio::test] +async fn request_without_path() { + let _ = env_logger::try_init(); + let (io, mut srv) = mock::new(); - conn.drive(response) - .and_then(|(conn, _)| conn.expect("client")) - }); + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); - client.join(srv).wait().expect("wait"); -} + srv.recv_frame( + frames::headers(1) + .request("GET", "http://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; -#[test] -fn request_without_path() { - let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame(frames::headers(1).request("GET", "http://example.com/").eos()) - .send_frame(frames::headers(1).response(200).eos()) - .close(); - - let client = client::handshake(io) - .expect("handshake") - .and_then(move |(mut client, conn)| { - // Note the lack of trailing slash. - let request = Request::get("http://example.com") - .body(()) - .unwrap(); + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("handshake"); + // Note the lack of trailing slash. + let request = Request::get("http://example.com").body(()).unwrap(); - let (response, _) = client.send_request(request, true).unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); - conn.drive(response) - }); + conn.drive(response).await.unwrap(); + }; - client.join(srv).wait().unwrap(); + join(srv, client).await; } -#[test] -fn request_options_with_star() { +#[tokio::test] +async fn request_options_with_star() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); // Note the lack of trailing slash. let uri = uri::Uri::from_parts({ @@ -876,208 +866,190 @@ fn request_options_with_star() { parts.authority = Some(uri::Authority::from_shared("example.com".into()).unwrap()); parts.path_and_query = Some(uri::PathAndQuery::from_static("*")); parts - }).unwrap(); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame(frames::headers(1).request("OPTIONS", uri.clone()).eos()) - .send_frame(frames::headers(1).response(200).eos()) - .close(); - - let client = client::handshake(io) - .expect("handshake") - .and_then(move |(mut client, conn)| { - let request = Request::builder() - .method(Method::OPTIONS) - .uri(uri) - .body(()) - .unwrap(); + }) + .unwrap(); + + let uri_clone = uri.clone(); + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("OPTIONS", uri_clone).eos()) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("handshake"); + let request = Request::builder() + .method(Method::OPTIONS) + .uri(uri) + .body(()) + .unwrap(); - let (response, _) = client.send_request(request, true).unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); - conn.drive(response) - }); + conn.drive(response).await.unwrap(); + }; - client.join(srv).wait().unwrap(); + join(srv, client).await; } -#[test] -fn notify_on_send_capacity() { +#[tokio::test] +async fn notify_on_send_capacity() { // This test ensures that the client gets notified when there is additional // send capacity. In other words, when the server is ready to accept a new // stream, the client is notified. - use std::sync::mpsc; + use futures::channel::oneshot; let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - let (done_tx, done_rx) = futures::sync::oneshot::channel(); - let (tx, rx) = mpsc::channel(); + let (io, mut srv) = mock::new(); + let (done_tx, done_rx) = oneshot::channel(); + let (tx, rx) = oneshot::channel(); let mut settings = frame::Settings::default(); settings.set_max_concurrent_streams(Some(1)); - let srv = srv - .assert_client_handshake_with_settings(settings) - .unwrap() + let srv = async move { + let settings = srv.assert_client_handshake_with_settings(settings).await; // This is the ACK - .recv_settings() - .map(move |h| { - tx.send(()).unwrap(); - h - }) - .recv_frame( + assert_default_settings!(settings); + tx.send(()).unwrap(); + srv.recv_frame( frames::headers(1) .request("GET", "https://www.example.com/") .eos(), ) - .send_frame(frames::headers(1).response(200).eos()) - .recv_frame( + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.recv_frame( frames::headers(3) .request("GET", "https://www.example.com/") .eos(), ) - .send_frame(frames::headers(3).response(200).eos()) - .recv_frame( + .await; + srv.send_frame(frames::headers(3).response(200).eos()).await; + srv.recv_frame( frames::headers(5) .request("GET", "https://www.example.com/") .eos(), ) - .send_frame(frames::headers(5).response(200).eos()) + .await; + srv.send_frame(frames::headers(5).response(200).eos()).await; // Don't close the connection until the client is done doing its // checks. - .wait_for(done_rx) - .close() - ; - - let client = client::handshake(io) - .expect("handshake") - .and_then(move |(mut client, conn)| { - ::std::thread::spawn(move || { - rx.recv().unwrap(); + done_rx.await.unwrap(); + }; - let mut responses = vec![]; + let client = async move { + let (mut client, conn) = client::handshake(io).await.expect("handshake"); + tokio::spawn(async move { + rx.await.unwrap(); - for _ in 0..3 { - // Wait for capacity. If the client is **not** notified, - // this hangs. - poll_fn(|| client.poll_ready()).wait().unwrap(); + let mut responses = vec![]; - let request = Request::builder() - .uri("https://www.example.com/") - .body(()) - .unwrap(); + for _ in 0..3usize { + // Wait for capacity. If the client is **not** notified, + // this hangs. + poll_fn(|cx| client.poll_ready(cx)).await.unwrap(); - let response = client.send_request(request, true) - .unwrap().0; + let request = Request::builder() + .uri("https://www.example.com/") + .body(()) + .unwrap(); - responses.push(response); - } + let response = client.send_request(request, true).unwrap().0; - for response in responses { - let response = response.wait().unwrap(); - assert_eq!(response.status(), StatusCode::OK); - } + responses.push(response); + } - poll_fn(|| client.poll_ready()).wait().unwrap(); + for response in responses { + let response = response.await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + } - done_tx.send(()).unwrap(); - }); + poll_fn(|cx| client.poll_ready(cx)).await.unwrap(); - conn.expect("h2") - }) - .expect("client"); + done_tx.send(()).unwrap(); + }); + conn.await.expect("h2"); + }; - client.join(srv).wait().unwrap(); + join(srv, client).await; } -#[test] -fn send_stream_poll_reset() { +#[tokio::test] +async fn send_stream_poll_reset() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - let srv = srv - .assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://example.com/") - ) - .send_frame(frames::reset(1).refused()) - .close(); - - let client = client::Builder::new() - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://example.com/")) + .await; + srv.send_frame(frames::reset(1).refused()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::Builder::new() + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); - let (_response, mut tx) = client.send_request(request, false).unwrap(); - conn.drive(futures::future::poll_fn(move || { - tx.poll_reset() - })) - .map(|(_, reason)| { - assert_eq!(reason, Reason::REFUSED_STREAM); - }) - }); + let (_response, mut tx) = client.send_request(request, false).unwrap(); + let reason = conn + .drive(poll_fn(move |cx| tx.poll_reset(cx))) + .await + .unwrap(); + assert_eq!(reason, Reason::REFUSED_STREAM); + }; - client.join(srv).wait().expect("wait"); + join(srv, client).await; } -#[test] -fn drop_pending_open() { +#[tokio::test] +async fn drop_pending_open() { // This test checks that a stream queued for pending open behaves correctly when its // client drops. let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - let (init_tx, init_rx) = futures::sync::oneshot::channel(); - let (trigger_go_away_tx, trigger_go_away_rx) = futures::sync::oneshot::channel(); - let (sent_go_away_tx, sent_go_away_rx) = futures::sync::oneshot::channel(); - let (drop_tx, drop_rx) = futures::sync::oneshot::channel(); + let (io, mut srv) = mock::new(); + let (init_tx, init_rx) = futures::channel::oneshot::channel(); + let (trigger_go_away_tx, trigger_go_away_rx) = futures::channel::oneshot::channel(); + let (sent_go_away_tx, sent_go_away_rx) = futures::channel::oneshot::channel(); + let (drop_tx, drop_rx) = futures::channel::oneshot::channel(); let mut settings = frame::Settings::default(); settings.set_max_concurrent_streams(Some(2)); - let srv = srv - .assert_client_handshake_with_settings(settings) - .unwrap() - // This is the ACK - .recv_settings() - .map(move |h| { - init_tx.send(()).unwrap(); - h - }) - .recv_frame( - frames::headers(1) - .request("GET", "https://www.example.com/"), - ) - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake_with_settings(settings).await; + // This is the ACK + assert_default_settings!(settings); + init_tx.send(()).unwrap(); + srv.recv_frame(frames::headers(1).request("GET", "https://www.example.com/")) + .await; + srv.recv_frame( frames::headers(3) .request("GET", "https://www.example.com/") .eos(), ) - .wait_for(trigger_go_away_rx) - .send_frame(frames::go_away(3)) - .map(move |h| { - sent_go_away_tx.send(()).unwrap(); - h - }) - .wait_for(drop_rx) - .send_frame(frames::headers(3).response(200).eos()) - .recv_frame( - frames::data(1, vec![]).eos(), - ) - .send_frame(frames::headers(1).response(200).eos()) - .close() - ; + .await; + trigger_go_away_rx.await.unwrap(); + srv.send_frame(frames::go_away(3)).await; + sent_go_away_tx.send(()).unwrap(); + drop_rx.await.unwrap(); + srv.send_frame(frames::headers(3).response(200).eos()).await; + srv.recv_frame(frames::data(1, vec![]).eos()).await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; fn request() -> Request<()> { Request::builder() @@ -1086,80 +1058,83 @@ fn drop_pending_open() { .unwrap() } - let client = client::Builder::new() - .max_concurrent_reset_streams(0) - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(move |(mut client, conn)| { - conn.expect("h2").join(init_rx.expect("init_rx").and_then(move |()| { - // Fill up the concurrent stream limit. - assert!(client.poll_ready().unwrap().is_ready()); - let mut response1 = client.send_request(request(), false).unwrap(); - assert!(client.poll_ready().unwrap().is_ready()); - let response2 = client.send_request(request(), true).unwrap(); - assert!(client.poll_ready().unwrap().is_ready()); - let response3 = client.send_request(request(), true).unwrap(); - - // Trigger a GOAWAY frame to invalidate our third request. - trigger_go_away_tx.send(()).unwrap(); - sent_go_away_rx.expect("sent_go_away_rx").and_then(move |_| { - // Now drop all the references to that stream. - drop(response3); - drop(client); - drop_tx.send(()).unwrap(); - - // Complete the second request, freeing up a stream. - response2.0.expect("resp2") - }).and_then(move |_| { - response1.1.send_data(Default::default(), true).unwrap(); - response1.0.expect("resp1") - }) - })) - }); + let client = async move { + let (mut client, conn) = client::Builder::new() + .max_concurrent_reset_streams(0) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + let f = async move { + init_rx.await.expect("init_rx"); + // Fill up the concurrent stream limit. + poll_fn(|cx| client.poll_ready(cx)).await.unwrap(); + let mut response1 = client.send_request(request(), false).unwrap(); + poll_fn(|cx| client.poll_ready(cx)).await.unwrap(); + let response2 = client.send_request(request(), true).unwrap(); + poll_fn(|cx| client.poll_ready(cx)).await.unwrap(); + let response3 = client.send_request(request(), true).unwrap(); + + // Trigger a GOAWAY frame to invalidate our third request. + trigger_go_away_tx.send(()).unwrap(); + sent_go_away_rx.await.expect("sent_go_away_rx"); + // Now drop all the references to that stream. + drop(response3); + drop(client); + drop_tx.send(()).unwrap(); + // Complete the second request, freeing up a stream. + response2.0.await.expect("resp2"); + response1.1.send_data(Default::default(), true).unwrap(); + response1.0.await.expect("resp1") + }; + + join( + async move { + conn.await.expect("h2"); + }, + f, + ) + .await; + }; - client.join(srv).wait().unwrap(); + join(srv, client).await; } -#[test] -fn malformed_response_headers_dont_unlink_stream() { +#[tokio::test] +async fn malformed_response_headers_dont_unlink_stream() { // This test checks that receiving malformed headers frame on a stream with // no remaining references correctly resets the stream, without prematurely // unlinking it. let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - let (drop_tx, drop_rx) = futures::sync::oneshot::channel(); - let (queued_tx, queued_rx) = futures::sync::oneshot::channel(); - - let srv = srv - .assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame(frames::headers(1).request("GET", "http://example.com/")) - .recv_frame(frames::headers(3).request("GET", "http://example.com/")) - .recv_frame(frames::headers(5).request("GET", "http://example.com/")) - .map(move |h| { - drop_tx.send(()).unwrap(); - h - }) - .wait_for(queued_rx) - .send_bytes(&[ + let (io, mut srv) = mock::new(); + let (drop_tx, drop_rx) = futures::channel::oneshot::channel(); + let (queued_tx, queued_rx) = futures::channel::oneshot::channel(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + + srv.recv_frame(frames::headers(1).request("GET", "http://example.com/")) + .await; + srv.recv_frame(frames::headers(3).request("GET", "http://example.com/")) + .await; + srv.recv_frame(frames::headers(5).request("GET", "http://example.com/")) + .await; + drop_tx.send(()).unwrap(); + queued_rx.await.unwrap(); + srv.send_bytes(&[ // 2 byte frame - 0, 0, 2, - // type: HEADERS - 1, - // flags: END_STREAM | END_HEADERS - 5, - // stream identifier: 3 - 0, 0, 0, 3, - // data - invalid (pseudo not at end of block) - 144, 135 - // Per the spec, this frame should cause a stream error of type - // PROTOCOL_ERROR. + 0, 0, 2, // type: HEADERS + 1, // flags: END_STREAM | END_HEADERS + 5, // stream identifier: 3 + 0, 0, 0, 3, // data - invalid (pseudo not at end of block) + 144, + 135, // Per the spec, this frame should cause a stream error of type + // PROTOCOL_ERROR. ]) - .close() - ; + .await; + }; fn request() -> Request<()> { Request::builder() @@ -1168,30 +1143,32 @@ fn malformed_response_headers_dont_unlink_stream() { .unwrap() } - let client = client::Builder::new() - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(move |(mut client, conn)| { - let (_req1, mut send1) = client.send_request( - request(), false).unwrap(); - // Use up most of the connection window. - send1.send_data(vec![0; 65534].into(), true).unwrap(); - let (req2, mut send2) = client.send_request( - request(), false).unwrap(); - let (req3, mut send3) = client.send_request( - request(), false).unwrap(); - conn.expect("h2").join(drop_rx.then(move |_| { - // Use up the remainder of the connection window. - send2.send_data(vec![0; 2].into(), true).unwrap(); - // Queue up for more connection window. - send3.send_data(vec![0; 1].into(), true).unwrap(); - queued_tx.send(()).unwrap(); - Ok((req2, req3)) - })) - }); - - - client.join(srv).wait().unwrap(); + let client = async move { + let (mut client, conn) = client::Builder::new() + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let (_req1, mut send1) = client.send_request(request(), false).unwrap(); + // Use up most of the connection window. + send1.send_data(vec![0; 65534].into(), true).unwrap(); + let (req2, mut send2) = client.send_request(request(), false).unwrap(); + let (req3, mut send3) = client.send_request(request(), false).unwrap(); + + let f = async move { + drop_rx.await.unwrap(); + // Use up the remainder of the connection window. + send2.send_data(vec![0; 2].into(), true).unwrap(); + // Queue up for more connection window. + send3.send_data(vec![0; 1].into(), true).unwrap(); + queued_tx.send(()).unwrap(); + drop((req2, req3)); + }; + + join(async move { conn.await.expect("h2") }, f).await; + }; + + join(srv, client).await; } const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; diff --git a/tests/h2-tests/tests/codec_read.rs b/tests/h2-tests/tests/codec_read.rs index 631db7565..e455cfc81 100644 --- a/tests/h2-tests/tests/codec_read.rs +++ b/tests/h2-tests/tests/codec_read.rs @@ -1,9 +1,11 @@ -use h2_support::prelude::*; +#![feature(async_await)] +use futures::future::join; +use h2_support::prelude::*; use std::error::Error; -#[test] -fn read_none() { +#[tokio::test] +async fn read_none() { let mut codec = Codec::from(mock_io::Builder::new().build()); assert_closed!(codec); @@ -15,8 +17,8 @@ fn read_frame_too_big() {} // ===== DATA ===== -#[test] -fn read_data_no_padding() { +#[tokio::test] +async fn read_data_no_padding() { let mut codec = raw_codec! { read => [ 0, 0, 5, 0, 0, 0, 0, 0, 1, @@ -32,8 +34,8 @@ fn read_data_no_padding() { assert_closed!(codec); } -#[test] -fn read_data_empty_payload() { +#[tokio::test] +async fn read_data_empty_payload() { let mut codec = raw_codec! { read => [ 0, 0, 0, 0, 0, 0, 0, 0, 1, @@ -48,8 +50,8 @@ fn read_data_empty_payload() { assert_closed!(codec); } -#[test] -fn read_data_end_stream() { +#[tokio::test] +async fn read_data_end_stream() { let mut codec = raw_codec! { read => [ 0, 0, 5, 0, 1, 0, 0, 0, 1, @@ -61,12 +63,11 @@ fn read_data_end_stream() { assert_eq!(data.stream_id(), 1); assert_eq!(data.payload(), &b"hello"[..]); assert!(data.is_end_stream()); - assert_closed!(codec); } -#[test] -fn read_data_padding() { +#[tokio::test] +async fn read_data_padding() { let mut codec = raw_codec! { read => [ 0, 0, 16, 0, 0x8, 0, 0, 0, 1, @@ -84,8 +85,8 @@ fn read_data_padding() { assert_closed!(codec); } -#[test] -fn read_push_promise() { +#[tokio::test] +async fn read_push_promise() { let mut codec = raw_codec! { read => [ 0, 0, 0x5, @@ -104,8 +105,8 @@ fn read_push_promise() { assert_closed!(codec); } -#[test] -fn read_data_stream_id_zero() { +#[tokio::test] +async fn read_data_stream_id_zero() { let mut codec = raw_codec! { read => [ 0, 0, 5, 0, 0, 0, 0, 0, 0, @@ -130,64 +131,70 @@ fn read_headers_with_pseudo() {} #[ignore] fn read_headers_empty_payload() {} -#[test] -fn read_continuation_frames() { +#[tokio::test] +async fn read_continuation_frames() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); let large = build_large_headers(); - let frame = large.iter().fold( - frames::headers(1).response(200), - |frame, &(name, ref value)| frame.field(name, &value[..]), - ).eos(); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let frame = large + .iter() + .fold( + frames::headers(1).response(200), + |frame, &(name, ref value)| frame.field(name, &value[..]), + ) + .eos(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .send_frame(frame) - .close(); - - let client = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); - - let req = client + .await; + srv.send_frame(frame).await; + }; + + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("handshake"); + + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + + let req = async { + let res = client .send_request(request, true) .expect("send_request") .0 - .expect("response") - .map(move |res| { - assert_eq!(res.status(), StatusCode::OK); - let (head, _body) = res.into_parts(); - let expected = large.iter().fold(HeaderMap::new(), |mut map, &(name, ref value)| { - use h2_support::frames::HttpTryInto; - map.append(name, value.as_str().try_into().unwrap()); - map - }); - assert_eq!(head.headers, expected); + .await + .expect("response"); + assert_eq!(res.status(), StatusCode::OK); + let (head, _body) = res.into_parts(); + let expected = large + .iter() + .fold(HeaderMap::new(), |mut map, &(name, ref value)| { + use h2_support::frames::HttpTryInto; + map.append(name, value.as_str().try_into().unwrap()); + map }); + assert_eq!(head.headers, expected); + }; - conn.drive(req) - .and_then(move |(h2, _)| { - h2.expect("client") - }).map(|c| (client, c)) - }); - - client.join(srv).wait().expect("wait"); + conn.drive(req).await; + conn.await.expect("client"); + }; + join(srv, client).await; } -#[test] -fn update_max_frame_len_at_rest() { +#[tokio::test] +async fn update_max_frame_len_at_rest() { + use futures::StreamExt; + let _ = env_logger::try_init(); // TODO: add test for updating max frame length in flight as well? let mut codec = raw_codec! { @@ -205,7 +212,7 @@ fn update_max_frame_len_at_rest() { assert_eq!(codec.max_recv_frame_size(), 16_384); assert_eq!( - codec.poll().unwrap_err().description(), + codec.next().await.unwrap().unwrap_err().description(), "frame with invalid size" ); } diff --git a/tests/h2-tests/tests/codec_write.rs b/tests/h2-tests/tests/codec_write.rs index dbb9e62a7..5537ba0b4 100644 --- a/tests/h2-tests/tests/codec_write.rs +++ b/tests/h2-tests/tests/codec_write.rs @@ -1,61 +1,55 @@ +#![feature(async_await)] + +use futures::future::join; use h2_support::prelude::*; -#[test] -fn write_continuation_frames() { +#[tokio::test] +async fn write_continuation_frames() { // An invalid dependency ID results in a stream level error. The hpack // payload should still be decoded. let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); let large = build_large_headers(); // Build the large request frame let frame = large.iter().fold( frames::headers(1).request("GET", "https://http2.akamai.com/"), - |frame, &(name, ref value)| frame.field(name, &value[..])); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame(frame.eos()) - .send_frame( - frames::headers(1) - .response(204) - .eos(), - ) - .close(); - - let client = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let mut request = Request::builder(); - request.uri("https://http2.akamai.com/"); - - for &(name, ref value) in &large { - request.header(name, &value[..]); - } - - let request = request - .body(()) - .unwrap(); - - let req = client + |frame, &(name, ref value)| frame.field(name, &value[..]), + ); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frame.eos()).await; + srv.send_frame(frames::headers(1).response(204).eos()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("handshake"); + + let mut request = Request::builder(); + request.uri("https://http2.akamai.com/"); + + for &(name, ref value) in &large { + request.header(name, &value[..]); + } + + let request = request.body(()).unwrap(); + + let req = async { + let res = client .send_request(request, true) .expect("send_request1") .0 - .then(|res| { - let response = res.unwrap(); - assert_eq!(response.status(), StatusCode::NO_CONTENT); - Ok::<_, ()>(()) - }); - - conn.drive(req) - .and_then(move |(h2, _)| { - h2.unwrap() - }).map(|c| { - (c, client) - }) - }); - - client.join(srv).wait().expect("wait"); + .await; + let response = res.unwrap(); + assert_eq!(response.status(), StatusCode::NO_CONTENT); + }; + + conn.drive(req).await; + conn.await.unwrap(); + }; + + join(srv, client).await; } diff --git a/tests/h2-tests/tests/flow_control.rs b/tests/h2-tests/tests/flow_control.rs index 5eedd504a..9cccf8dd6 100644 --- a/tests/h2-tests/tests/flow_control.rs +++ b/tests/h2-tests/tests/flow_control.rs @@ -1,9 +1,13 @@ +#![feature(async_await)] +use futures::future::{join, join4}; +use futures::{StreamExt, TryStreamExt}; use h2_support::prelude::*; +use h2_support::util::yield_once; // In this case, the stream & connection both have capacity, but capacity is not // explicitly requested. -#[test] -fn send_data_without_requesting_capacity() { +#[tokio::test] +async fn send_data_without_requesting_capacity() { let _ = env_logger::try_init(); let payload = [0; 1024]; @@ -12,8 +16,8 @@ fn send_data_without_requesting_capacity() { .handshake() .write(&[ // POST / - 0, 0, 16, 1, 4, 0, 0, 0, 1, 131, 135, 65, 139, 157, 41, - 172, 75, 143, 168, 233, 25, 151, 33, 233, 132, + 0, 0, 16, 1, 4, 0, 0, 0, 1, 131, 135, 65, 139, 157, 41, 172, 75, 143, 168, 233, 25, 151, + 33, 233, 132, ]) .write(&[ // DATA @@ -25,7 +29,7 @@ fn send_data_without_requesting_capacity() { .read(&[0, 0, 1, 1, 5, 0, 0, 0, 1, 0x89]) .build(); - let (mut client, mut h2) = client::handshake(mock).wait().unwrap(); + let (mut client, mut h2) = client::handshake(mock).await.unwrap(); let request = Request::builder() .method(Method::POST) @@ -42,140 +46,126 @@ fn send_data_without_requesting_capacity() { stream.send_data(payload[..].into(), true).unwrap(); // Get the response - let resp = h2.run(response).unwrap(); + let resp = h2.run(response).await.unwrap(); assert_eq!(resp.status(), StatusCode::NO_CONTENT); - h2.wait().unwrap(); + h2.await.unwrap(); } -#[test] -fn release_capacity_sends_window_update() { +#[tokio::test] +async fn release_capacity_sends_window_update() { let _ = env_logger::try_init(); let payload = vec![0u8; 16_384]; + let payload_len = payload.len(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let mock = srv.assert_client_handshake().unwrap() - .recv_settings() - .recv_frame( + let mock = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") - .eos() - ) - .send_frame( - frames::headers(1) - .response(200) + .eos(), ) - .send_frame(frames::data(1, &payload[..])) - .send_frame(frames::data(1, &payload[..])) - .send_frame(frames::data(1, &payload[..])) - .recv_frame( - frames::window_update(0, 32_768) - ) - .recv_frame( - frames::window_update(1, 32_768) - ) - .send_frame(frames::data(1, &payload[..]).eos()) - // gotta end the connection - .map(drop); - - let h2 = client::handshake(io).unwrap().and_then(|(mut client, h2)| { + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.send_frame(frames::data(1, &payload[..])).await; + srv.send_frame(frames::data(1, &payload[..])).await; + srv.send_frame(frames::data(1, &payload[..])).await; + srv.recv_frame(frames::window_update(0, 32_768)).await; + srv.recv_frame(frames::window_update(1, 32_768)).await; + srv.send_frame(frames::data(1, &payload[..]).eos()).await; + }; + + let h2 = async move { + let (mut client, h2) = client::handshake(io).await.unwrap(); let request = Request::builder() .method(Method::GET) .uri("https://http2.akamai.com/") .body(()) .unwrap(); - let req = client.send_request(request, true).unwrap() - .0 - .unwrap() - // Get the response - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - let body = resp.into_parts().1; - body.into_future().unwrap() - }) + let req = async move { + let resp = client.send_request(request, true).unwrap().0.await.unwrap(); + // Get the response + assert_eq!(resp.status(), StatusCode::OK); + let mut body = resp.into_parts().1; - // read some body to use up window size to below half - .and_then(|(buf, body)| { - assert_eq!(buf.unwrap().len(), payload.len()); - body.into_future().unwrap() - }) - .and_then(|(buf, body)| { - assert_eq!(buf.unwrap().len(), payload.len()); - body.into_future().unwrap() - }) - .and_then(|(buf, mut body)| { - let buf = buf.unwrap(); - assert_eq!(buf.len(), payload.len()); - body.release_capacity().release_capacity(buf.len() * 2).unwrap(); - body.into_future().unwrap() - }) - .and_then(|(buf, _)| { - assert_eq!(buf.unwrap().len(), payload.len()); - Ok(()) - }); + // read some body to use up window size to below half + let buf = body.next().await.unwrap().unwrap(); + assert_eq!(buf.len(), payload_len); - h2.unwrap().join(req) - }); - h2.join(mock).wait().unwrap(); + let buf = body.next().await.unwrap().unwrap(); + assert_eq!(buf.len(), payload_len); + + let buf = body.next().await.unwrap().unwrap(); + assert_eq!(buf.len(), payload_len); + body.release_capacity() + .release_capacity(buf.len() * 2) + .unwrap(); + + let buf = body.next().await.unwrap().unwrap(); + assert_eq!(buf.len(), payload_len); + }; + + join( + async move { + h2.await.unwrap(); + }, + req, + ) + .await + }; + join(mock, h2).await; } -#[test] -fn release_capacity_of_small_amount_does_not_send_window_update() { +#[tokio::test] +async fn release_capacity_of_small_amount_does_not_send_window_update() { let _ = env_logger::try_init(); let payload = [0; 16]; - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let mock = srv.assert_client_handshake().unwrap() - .recv_settings() - .recv_frame( + let mock = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") - .eos() + .eos(), ) - .send_frame( - frames::headers(1) - .response(200) - ) - .send_frame(frames::data(1, &payload[..]).eos()) - // gotta end the connection - .map(drop); + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.send_frame(frames::data(1, &payload[..]).eos()).await; + }; - let h2 = client::handshake(io).unwrap().and_then(|(mut client, h2)| { + let h2 = async move { + let (mut client, h2) = client::handshake(io).await.unwrap(); let request = Request::builder() .method(Method::GET) .uri("https://http2.akamai.com/") .body(()) .unwrap(); - let req = client.send_request(request, true).unwrap() - .0 - .unwrap() - // Get the response - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - let body = resp.into_parts().1; - assert!(!body.is_end_stream()); - body.into_future().unwrap() - }) - // read the small body and then release it - .and_then(|(buf, mut body)| { - let buf = buf.unwrap(); - assert_eq!(buf.len(), 16); - body.release_capacity().release_capacity(buf.len()).unwrap(); - body.into_future().unwrap() - }) - .and_then(|(buf, _)| { - assert!(buf.is_none()); - Ok(()) - }); - h2.unwrap().join(req) - }); - h2.join(mock).wait().unwrap(); + let req = async move { + let resp = client.send_request(request, true).unwrap().0.await.unwrap(); + // Get the response + assert_eq!(resp.status(), StatusCode::OK); + let mut body = resp.into_parts().1; + assert!(!body.is_end_stream()); + let buf = body.next().await.unwrap().unwrap(); + // read the small body and then release it + assert_eq!(buf.len(), 16); + body.release_capacity().release_capacity(buf.len()).unwrap(); + let buf = body.next().await; + assert!(buf.is_none()); + }; + join(async move { h2.await.unwrap() }, req).await; + }; + join(mock, h2).await; } #[test] @@ -186,133 +176,119 @@ fn expand_window_sends_window_update() {} #[ignore] fn expand_window_calls_are_coalesced() {} -#[test] -fn recv_data_overflows_connection_window() { +#[tokio::test] +async fn recv_data_overflows_connection_window() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let mock = srv.assert_client_handshake().unwrap() - .recv_settings() - .recv_frame( + let mock = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") - .eos() - ) - .send_frame( - frames::headers(1) - .response(200) + .eos(), ) + .await; + srv.send_frame(frames::headers(1).response(200)).await; // fill the whole window - .send_frame(frames::data(1, vec![0u8; 16_384])) - .send_frame(frames::data(1, vec![0u8; 16_384])) - .send_frame(frames::data(1, vec![0u8; 16_384])) - .send_frame(frames::data(1, vec![0u8; 16_383])) + srv.send_frame(frames::data(1, vec![0u8; 16_384])).await; + srv.send_frame(frames::data(1, vec![0u8; 16_384])).await; + srv.send_frame(frames::data(1, vec![0u8; 16_384])).await; + srv.send_frame(frames::data(1, vec![0u8; 16_383])).await; // this frame overflows the window! - .send_frame(frames::data(1, vec![0u8; 128]).eos()) + srv.send_frame(frames::data(1, vec![0u8; 128]).eos()).await; // expecting goaway for the conn, not stream - .recv_frame(frames::go_away(0).flow_control()); - // connection is ended by client + srv.recv_frame(frames::go_away(0).flow_control()).await; + // connection is ended by client + }; - let h2 = client::handshake(io).unwrap().and_then(|(mut client, h2)| { + let h2 = async move { + let (mut client, h2) = client::handshake(io).await.unwrap(); let request = Request::builder() .method(Method::GET) .uri("https://http2.akamai.com/") .body(()) .unwrap(); - let req = client - .send_request(request, true) - .unwrap() - .0.unwrap() - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - let body = resp.into_parts().1; - body.concat2().then(|res| { - let err = res.unwrap_err(); - assert_eq!( - err.to_string(), - "protocol error: flow-control protocol violated" - ); - Ok::<(), ()>(()) - }) - }); + let req = async move { + let resp = client.send_request(request, true).unwrap().0.await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body = resp.into_parts().1; + let res = body.try_concat().await; + let err = res.unwrap_err(); + assert_eq!( + err.to_string(), + "protocol error: flow-control protocol violated" + ); + }; // client should see a flow control error - let conn = h2.then(|res| { + let conn = async move { + let res = h2.await; let err = res.unwrap_err(); assert_eq!( err.to_string(), "protocol error: flow-control protocol violated" ); - Ok::<(), ()>(()) - }); - conn.unwrap().join(req) - }); - h2.join(mock).wait().unwrap(); + }; + join(conn, req).await; + }; + join(mock, h2).await; } -#[test] -fn recv_data_overflows_stream_window() { +#[tokio::test] +async fn recv_data_overflows_stream_window() { // this tests for when streams have smaller windows than their connection let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let mock = srv.assert_client_handshake().unwrap() - .ignore_settings() - .recv_frame( + let mock = async move { + let _ = srv.assert_client_handshake().await; + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") - .eos() - ) - .send_frame( - frames::headers(1) - .response(200) + .eos(), ) + .await; + srv.send_frame(frames::headers(1).response(200)).await; // fill the whole window - .send_frame(frames::data(1, vec![0u8; 16_384])) + srv.send_frame(frames::data(1, vec![0u8; 16_384])).await; // this frame overflows the window! - .send_frame(frames::data(1, &[0; 16][..]).eos()) - .recv_frame(frames::reset(1).flow_control()) - .close(); - - let h2 = client::Builder::new() - .initial_window_size(16_384) - .handshake::<_, Bytes>(io) - .unwrap() - .and_then(|(mut client, conn)| { - let request = Request::builder() - .method(Method::GET) - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); - - let req = client - .send_request(request, true) - .unwrap() - .0.unwrap() - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - let body = resp.into_parts().1; - body.concat2().then(|res| { - let err = res.unwrap_err(); - assert_eq!( - err.to_string(), - "protocol error: flow-control protocol violated" - ); - Ok::<(), ()>(()) - }) - }); - - conn.unwrap() - .join(req) - .map(|c| (c, client)) - }); - h2.join(mock).wait().unwrap(); -} + srv.send_frame(frames::data(1, &[0; 16][..]).eos()).await; + srv.recv_frame(frames::reset(1).flow_control()).await; + }; + + let h2 = async move { + let (mut client, conn) = client::Builder::new() + .initial_window_size(16_384) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method(Method::GET) + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + let req = async move { + let resp = client.send_request(request, true).unwrap().0.await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body = resp.into_parts().1; + let res = body.try_concat().await; + let err = res.unwrap_err(); + assert_eq!( + err.to_string(), + "protocol error: flow-control protocol violated" + ); + }; + join(async move { conn.await.unwrap() }, req).await; + }; + join(mock, h2).await; +} #[test] #[ignore] @@ -320,85 +296,93 @@ fn recv_window_update_causes_overflow() { // A received window update causes the window to overflow. } -#[test] -fn stream_error_release_connection_capacity() { +#[tokio::test] +async fn stream_error_release_connection_capacity() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") - .eos() + .eos(), ) + .await; // we're sending the wrong content-length - .send_frame( + srv.send_frame( frames::headers(1) .response(200) - .field("content-length", &*(16_384 * 3).to_string()) + .field("content-length", &*(16_384 * 3).to_string()), ) - .send_frame(frames::data(1, vec![0; 16_384])) - .send_frame(frames::data(1, vec![0; 16_384])) - .send_frame(frames::data(1, vec![0; 10]).eos()) + .await; + srv.send_frame(frames::data(1, vec![0; 16_384])).await; + srv.send_frame(frames::data(1, vec![0; 16_384])).await; + srv.send_frame(frames::data(1, vec![0; 10]).eos()).await; // mismatched content-length is a protocol error - .recv_frame(frames::reset(1).protocol_error()) + srv.recv_frame(frames::reset(1).protocol_error()).await; // but then the capacity should be released automatically - .recv_frame(frames::window_update(0, 16_384 * 2 + 10)) - .close(); + srv.recv_frame(frames::window_update(0, 16_384 * 2 + 10)) + .await; + }; - let client = client::handshake(io).unwrap() - .and_then(|(mut client, conn)| { - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()).unwrap(); + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.unwrap(); + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - let req = client.send_request(request, true) + let req = async { + let resp = client + .send_request(request, true) .unwrap() - .0.expect("response") - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - let mut body = resp.into_parts().1; - let mut cap = body.release_capacity().clone(); - let to_release = 16_384 * 2; - let mut should_recv_bytes = to_release; - let mut should_recv_frames = 2; - body - .for_each(move |bytes| { - should_recv_bytes -= bytes.len(); - should_recv_frames -= 1; - if should_recv_bytes == 0 { - assert_eq!(should_recv_bytes, 0); - } - - Ok(()) - }) - .expect_err("body") - .map(move |err| { - assert_eq!( - err.to_string(), - "protocol error: unspecific protocol error detected" - ); - cap.release_capacity(to_release).expect("release_capacity"); - }) - }); - conn.drive(req.expect("response")) - .and_then(|(conn, _)| conn.expect("client")) - .map(|c| (c, client)) - }); - - srv.join(client).wait().unwrap(); + .0 + .await + .expect("response"); + assert_eq!(resp.status(), StatusCode::OK); + let mut body = resp.into_parts().1; + let mut cap = body.release_capacity().clone(); + let to_release = 16_384 * 2; + let mut should_recv_bytes = to_release; + let mut should_recv_frames = 2usize; + + let err = body + .try_for_each(|bytes| { + async move { + should_recv_bytes -= bytes.len(); + should_recv_frames -= 1; + if should_recv_bytes == 0 { + assert_eq!(should_recv_bytes, 0); + } + Ok(()) + } + }) + .await + .expect_err("body"); + assert_eq!( + err.to_string(), + "protocol error: unspecific protocol error detected" + ); + cap.release_capacity(to_release).expect("release_capacity"); + }; + conn.drive(req).await; + conn.await.expect("client"); + }; + + join(srv, client).await; } -#[test] -fn stream_close_by_data_frame_releases_capacity() { +#[tokio::test] +async fn stream_close_by_data_frame_releases_capacity() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); let window_size = frame::DEFAULT_INITIAL_WINDOW_SIZE as usize; - let h2 = client::handshake(io).unwrap().and_then(|(mut client, h2)| { + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); let request = Request::builder() .method(Method::POST) .uri("https://http2.akamai.com/") @@ -443,33 +427,34 @@ fn stream_close_by_data_frame_releases_capacity() { // Drive both streams to prevent the handles from being dropped // (which will send a RST_STREAM) before the connection is closed. - h2.drive(resp1) - .and_then(move |(h2, _)| h2.drive(resp2)) - }) - .unwrap(); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) - .send_frame(frames::headers(1).response(200)) - .recv_frame(frames::headers(3).request("POST", "https://http2.akamai.com/")) - .send_frame(frames::headers(3).response(200)) - .recv_frame(frames::data(1, &b""[..]).eos()) - .recv_frame(frames::data(3, &b"hello"[..]).eos()) - .close(); - - let _ = h2.join(srv).wait().unwrap(); + h2.drive(resp1).await.unwrap(); + h2.drive(resp2).await.unwrap(); + }; + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.recv_frame(frames::headers(3).request("POST", "https://http2.akamai.com/")) + .await; + srv.send_frame(frames::headers(3).response(200)).await; + srv.recv_frame(frames::data(1, &b""[..]).eos()).await; + srv.recv_frame(frames::data(3, &b"hello"[..]).eos()).await; + }; + join(srv, h2).await; } -#[test] -fn stream_close_by_trailers_frame_releases_capacity() { +#[tokio::test] +async fn stream_close_by_trailers_frame_releases_capacity() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); let window_size = frame::DEFAULT_INITIAL_WINDOW_SIZE as usize; - let h2 = client::handshake(io).unwrap().and_then(|(mut client, h2)| { + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); let request = Request::builder() .method(Method::POST) .uri("https://http2.akamai.com/") @@ -514,596 +499,540 @@ fn stream_close_by_trailers_frame_releases_capacity() { // Drive both streams to prevent the handles from being dropped // (which will send a RST_STREAM) before the connection is closed. - h2.drive(resp1) - .and_then(move |(h2, _)| h2.drive(resp2)) - }) - .unwrap(); + h2.drive(resp1).await.unwrap(); + h2.drive(resp2).await.unwrap(); + }; - let srv = srv.assert_client_handshake().unwrap() + let srv = async move { + let settings = srv.assert_client_handshake().await; // Get the first frame - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://http2.akamai.com/") - ) - .send_frame(frames::headers(1).response(200)) - .recv_frame( - frames::headers(3) - .request("POST", "https://http2.akamai.com/") - ) - .send_frame(frames::headers(3).response(200)) - .recv_frame(frames::headers(1).eos()) - .recv_frame(frames::data(3, &b"hello"[..]).eos()) - .close(); - - let _ = h2.join(srv).wait().unwrap(); + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.recv_frame(frames::headers(3).request("POST", "https://http2.akamai.com/")) + .await; + srv.send_frame(frames::headers(3).response(200)).await; + srv.recv_frame(frames::headers(1).eos()).await; + srv.recv_frame(frames::data(3, &b"hello"[..]).eos()).await; + }; + join(srv, h2).await; } -#[test] -fn stream_close_by_send_reset_frame_releases_capacity() { +#[tokio::test] +async fn stream_close_by_send_reset_frame_releases_capacity() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") - .eos() + .eos(), ) - .send_frame(frames::headers(1).response(200)) - .send_frame(frames::data(1, vec![0; 16_384])) - .send_frame(frames::data(1, vec![0; 16_384]).eos()) - .recv_frame(frames::window_update(0, 16_384 * 2)) - .recv_frame( + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.send_frame(frames::data(1, vec![0; 16_384])).await; + srv.send_frame(frames::data(1, vec![0; 16_384]).eos()).await; + srv.recv_frame(frames::window_update(0, 16_384 * 2)).await; + srv.recv_frame( frames::headers(3) .request("GET", "https://http2.akamai.com/") - .eos() + .eos(), ) - .send_frame(frames::headers(3).response(200).eos()) - .close(); + .await; + srv.send_frame(frames::headers(3).response(200).eos()).await; + }; - let client = client::handshake(io).expect("client handshake") - .and_then(|(mut client, conn)| { + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("client handshake"); + { let request = Request::builder() .uri("https://http2.akamai.com/") - .body(()).unwrap(); + .body(()) + .unwrap(); let (resp, _) = client.send_request(request, true).unwrap(); - conn.drive(resp.expect("response")).map(move |c| (c, client)) - }) - .and_then(|((conn, _res), mut client)| { - // ^-- ignore the response body + let _res = conn.drive(resp).await; + // ^-- ignore the response body + } + let resp = { let request = Request::builder() .uri("https://http2.akamai.com/") - .body(()).unwrap(); + .body(()) + .unwrap(); let (resp, _) = client.send_request(request, true).unwrap(); - conn.drive(resp.expect("response")) - }) - .and_then(|(conn, _res)| { - conn.expect("client conn") - }); - - srv.join(client).wait().expect("wait"); + drop(client); + resp + }; + let _res = conn.drive(resp).await; + conn.await.expect("client conn"); + }; + + join(srv, client).await; } #[test] #[ignore] fn stream_close_by_recv_reset_frame_releases_capacity() {} -#[test] -fn recv_window_update_on_stream_closed_by_data_frame() { +#[tokio::test] +async fn recv_window_update_on_stream_closed_by_data_frame() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let h2 = client::handshake(io) - .unwrap() - .and_then(|(mut client, h2)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - let (response, stream) = client.send_request(request, false).unwrap(); + let (response, mut stream) = client.send_request(request, false).unwrap(); - // Wait for the response - h2.drive(response.map(|response| (response, stream))) - }) - .and_then(|(h2, (response, mut stream))| { - assert_eq!(response.status(), StatusCode::OK); - - // Send a data frame, this will also close the connection - stream.send_data("hello".into(), true).unwrap(); - - // keep `stream` from being dropped in order to prevent - // it from sending an RST_STREAM frame. - // - // i know this is kind of evil, but it's necessary to - // ensure that the stream is closed by the EOS frame, - // and not by the RST_STREAM. - std::mem::forget(stream); - - // Wait for the connection to close - h2.unwrap() - }); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) - .send_frame(frames::headers(1).response(200)) - .recv_frame(frames::data(1, "hello").eos()) - .send_frame(frames::window_update(1, 5)) - .map(drop); + // Wait for the response + let response = h2.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + // Send a data frame, this will also close the connection + stream.send_data("hello".into(), true).unwrap(); - let _ = h2.join(srv).wait().unwrap(); + // keep `stream` from being dropped in order to prevent + // it from sending an RST_STREAM frame. + // + // i know this is kind of evil, but it's necessary to + // ensure that the stream is closed by the EOS frame, + // and not by the RST_STREAM. + std::mem::forget(stream); + + // Wait for the connection to close + h2.await.unwrap(); + }; + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.recv_frame(frames::data(1, "hello").eos()).await; + srv.send_frame(frames::window_update(1, 5)).await; + }; + join(srv, h2).await; } -#[test] -fn reserved_capacity_assigned_in_multi_window_updates() { +#[tokio::test] +async fn reserved_capacity_assigned_in_multi_window_updates() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let h2 = client::handshake(io) - .unwrap() - .and_then(|(mut client, h2)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - let (response, mut stream) = client.send_request(request, false).unwrap(); + let (response, mut stream) = client.send_request(request, false).unwrap(); - // Consume the capacity - let payload = vec![0; frame::DEFAULT_INITIAL_WINDOW_SIZE as usize]; - stream.send_data(payload.into(), false).unwrap(); + // Consume the capacity + let payload = vec![0; frame::DEFAULT_INITIAL_WINDOW_SIZE as usize]; + stream.send_data(payload.into(), false).unwrap(); - // Reserve more data than we want - stream.reserve_capacity(10); + // Reserve more data than we want + stream.reserve_capacity(10); - h2.drive( - util::wait_for_capacity(stream, 5) - .map(|stream| (response, client, stream))) - }) - .and_then(|(h2, (response, client, mut stream))| { - stream.send_data("hello".into(), false).unwrap(); - stream.send_data("world".into(), true).unwrap(); + let mut stream = h2.drive(util::wait_for_capacity(stream, 5)).await; + stream.send_data("hello".into(), false).unwrap(); + stream.send_data("world".into(), true).unwrap(); - h2.drive(response).map(|c| (c, client)) - }) - .and_then(|((h2, response), client)| { - assert_eq!(response.status(), StatusCode::NO_CONTENT); + let response = h2.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::NO_CONTENT); - // Wait for the connection to close - h2.unwrap().map(|c| (c, client)) - }); + // Wait for the connection to close + h2.await.unwrap(); + }; - let srv = srv.assert_client_handshake().unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://http2.akamai.com/") - ) - .recv_frame(frames::data(1, vec![0u8; 16_384])) - .recv_frame(frames::data(1, vec![0u8; 16_384])) - .recv_frame(frames::data(1, vec![0u8; 16_384])) - .recv_frame(frames::data(1, vec![0u8; 16_383])) - .idle_ms(100) + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) + .await; + srv.recv_frame(frames::data(1, vec![0u8; 16_384])).await; + srv.recv_frame(frames::data(1, vec![0u8; 16_384])).await; + srv.recv_frame(frames::data(1, vec![0u8; 16_384])).await; + srv.recv_frame(frames::data(1, vec![0u8; 16_383])).await; + idle_ms(100).await; // Increase the connection window - .send_frame( - frames::window_update(0, 10)) + srv.send_frame(frames::window_update(0, 10)).await; // Incrementally increase the stream window - .send_frame( - frames::window_update(1, 4)) - .idle_ms(50) - .send_frame( - frames::window_update(1, 1)) + srv.send_frame(frames::window_update(1, 4)).await; + idle_ms(50).await; + srv.send_frame(frames::window_update(1, 1)).await; // Receive first chunk - .recv_frame(frames::data(1, "hello")) - .send_frame( - frames::window_update(1, 5)) + srv.recv_frame(frames::data(1, "hello")).await; + srv.send_frame(frames::window_update(1, 5)).await; // Receive second chunk - .recv_frame( - frames::data(1, "world").eos()) - .send_frame( - frames::headers(1) - .response(204) - .eos() - ) + srv.recv_frame(frames::data(1, "world").eos()).await; + srv.send_frame(frames::headers(1).response(204).eos()).await; /* .recv_frame(frames::data(1, "hello").eos()) .send_frame(frames::window_update(1, 5)) */ - .map(drop); - - let _ = h2.join(srv).wait().unwrap(); + }; + join(srv, h2).await; } -#[test] -fn connection_notified_on_released_capacity() { - use crate::futures::sync::oneshot; - use std::sync::mpsc; - use std::thread; +#[tokio::test] +async fn connection_notified_on_released_capacity() { + use futures::channel::mpsc; + use futures::channel::oneshot; let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); // We're going to run the connection on a thread in order to isolate task // notifications. This test is here, in part, to ensure that the connection // receives the appropriate notifications to send out window updates. - let (tx, rx) = mpsc::channel(); + let (tx, mut rx) = mpsc::unbounded(); // Because threading is fun let (settings_tx, settings_rx) = oneshot::channel(); - let th1 = thread::spawn(move || { - srv.assert_client_handshake().unwrap() - .recv_settings() - .map(move |v| { - settings_tx.send(()).unwrap(); - v - }) - // Get the first request - .recv_frame( - frames::headers(1) - .request("GET", "https://example.com/a") - .eos()) - // Get the second request - .recv_frame( - frames::headers(3) - .request("GET", "https://example.com/b") - .eos()) - // Send the first response - .send_frame(frames::headers(1).response(200)) - // Send the second response - .send_frame(frames::headers(3).response(200)) - - // Fill the connection window - .send_frame(frames::data(1, vec![0u8; 16_384]).eos()) - .idle_ms(100) - .send_frame(frames::data(3, vec![0u8; 16_384]).eos()) - - // The window update is sent - .recv_frame(frames::window_update(0, 16_384)) - .map(drop) - .wait().unwrap(); - }); - - - let th2 = thread::spawn(move || { - let (mut client, h2) = client::handshake(io).wait().unwrap(); + let (th1_tx, th1_rx) = oneshot::channel(); - let (h2, _) = h2.drive(settings_rx).wait().unwrap(); + tokio::spawn(async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + settings_tx.send(()).unwrap(); + // Get the first request + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/a") + .eos(), + ) + .await; + // Get the second request + srv.recv_frame( + frames::headers(3) + .request("GET", "https://example.com/b") + .eos(), + ) + .await; + // Send the first response + srv.send_frame(frames::headers(1).response(200)).await; + // Send the second response + srv.send_frame(frames::headers(3).response(200)).await; + + // Fill the connection window + srv.send_frame(frames::data(1, vec![0u8; 16_384]).eos()) + .await; + idle_ms(100).await; + srv.send_frame(frames::data(3, vec![0u8; 16_384]).eos()) + .await; + + // The window update is sent + srv.recv_frame(frames::window_update(0, 16_384)).await; + + th1_tx.send(()).unwrap(); + }); - let request = Request::get("https://example.com/a").body(()).unwrap(); + let (th2_tx, th2_rx) = oneshot::channel(); - tx.send(client.send_request(request, true).unwrap()) - .unwrap(); + let (mut client, mut h2) = client::handshake(io).await.unwrap(); - let request = Request::get("https://example.com/b").body(()).unwrap(); + h2.drive(settings_rx).await.unwrap(); + let request = Request::get("https://example.com/a").body(()).unwrap(); + tx.unbounded_send(client.send_request(request, true).unwrap().0) + .unwrap(); - tx.send(client.send_request(request, true).unwrap()) - .unwrap(); + let request = Request::get("https://example.com/b").body(()).unwrap(); + tx.unbounded_send(client.send_request(request, true).unwrap().0) + .unwrap(); + tokio::spawn(async move { // Run the connection to completion - h2.wait().unwrap(); + h2.await.unwrap(); + + th2_tx.send(()).unwrap(); + drop(client); }); // Get the two requests - let (a, _) = rx.recv().unwrap(); - let (b, _) = rx.recv().unwrap(); + let a = rx.next().await.unwrap(); + let b = rx.next().await.unwrap(); // Get the first response - let response = a.wait().unwrap(); + let response = a.await.unwrap(); assert_eq!(response.status(), StatusCode::OK); - let (_, a) = response.into_parts(); + let (_, mut a) = response.into_parts(); // Get the next chunk - let (chunk, mut a) = a.into_future().wait().unwrap(); + let chunk = a.next().await.unwrap(); assert_eq!(16_384, chunk.unwrap().len()); // Get the second response - let response = b.wait().unwrap(); + let response = b.await.unwrap(); assert_eq!(response.status(), StatusCode::OK); - let (_, b) = response.into_parts(); + let (_, mut b) = response.into_parts(); // Get the next chunk - let (chunk, b) = b.into_future().wait().unwrap(); + let chunk = b.next().await.unwrap(); assert_eq!(16_384, chunk.unwrap().len()); // Wait a bit - thread::sleep(Duration::from_millis(100)); + idle_ms(100).await; // Release the capacity a.release_capacity().release_capacity(16_384).unwrap(); - th1.join().unwrap(); - th2.join().unwrap(); + th1_rx.await.unwrap(); + th2_rx.await.unwrap(); // Explicitly drop this after the joins so that the capacity doesn't get // implicitly released before. drop(b); } -#[test] -fn recv_settings_removes_available_capacity() { +#[tokio::test] +async fn recv_settings_removes_available_capacity() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); let mut settings = frame::Settings::default(); settings.set_initial_window_size(Some(0)); - let srv = srv.assert_client_handshake_with_settings(settings).unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://http2.akamai.com/") - ) - .idle_ms(100) - .send_frame(frames::window_update(0, 11)) - .send_frame(frames::window_update(1, 11)) - .recv_frame(frames::data(1, "hello world").eos()) - .send_frame( - frames::headers(1) - .response(204) - .eos() - ) - .close(); - - - let h2 = client::handshake(io).unwrap() - .and_then(|(mut client, h2)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://http2.akamai.com/") - .body(()).unwrap(); - - let (response, mut stream) = client.send_request(request, false).unwrap(); + let srv = async move { + let settings = srv.assert_client_handshake_with_settings(settings).await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) + .await; + idle_ms(100).await; + srv.send_frame(frames::window_update(0, 11)).await; + srv.send_frame(frames::window_update(1, 11)).await; + srv.recv_frame(frames::data(1, "hello world").eos()).await; + srv.send_frame(frames::headers(1).response(204).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - stream.reserve_capacity(11); + let (response, mut stream) = client.send_request(request, false).unwrap(); - h2.drive(util::wait_for_capacity(stream, 11).map(|s| (response, client, s))) - }) - .and_then(|(h2, (response, client, mut stream))| { - assert_eq!(stream.capacity(), 11); + stream.reserve_capacity(11); - stream.send_data("hello world".into(), true).unwrap(); + let mut stream = h2.drive(util::wait_for_capacity(stream, 11)).await; + assert_eq!(stream.capacity(), 11); - h2.drive(response).map(|c| (c, client)) - }) - .and_then(|((h2, response), client)| { - assert_eq!(response.status(), StatusCode::NO_CONTENT); + stream.send_data("hello world".into(), true).unwrap(); - // Wait for the connection to close - // Hold on to the `client` handle to avoid sending a GO_AWAY frame. - h2.unwrap().map(|c| (c, client)) - }); + let response = h2.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::NO_CONTENT); - let _ = h2.join(srv) - .wait().unwrap(); + // Wait for the connection to close + // Hold on to the `client` handle to avoid sending a GO_AWAY frame. + h2.await.unwrap(); + }; + join(srv, h2).await; } -#[test] -fn recv_settings_keeps_assigned_capacity() { +#[tokio::test] +async fn recv_settings_keeps_assigned_capacity() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); + + let (sent_settings, sent_settings_rx) = futures::channel::oneshot::channel(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) + .await; + srv.send_frame(frames::settings().initial_window_size(64)) + .await; + srv.recv_frame(frames::settings_ack()).await; + sent_settings.send(()).unwrap(); + srv.recv_frame(frames::data(1, "hello world").eos()).await; + srv.send_frame(frames::headers(1).response(204).eos()).await; + }; + + let h2 = async move { + let (mut client, h2) = client::handshake(io).await.unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - let (sent_settings, sent_settings_rx) = futures::sync::oneshot::channel(); + let (response, mut stream) = client.send_request(request, false).unwrap(); - let srv = srv.assert_client_handshake().unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://http2.akamai.com/") - ) - .send_frame(frames::settings().initial_window_size(64)) - .recv_frame(frames::settings_ack()) - .then_notify(sent_settings) - .recv_frame(frames::data(1, "hello world").eos()) - .send_frame( - frames::headers(1) - .response(204) - .eos() - ) - .close(); + stream.reserve_capacity(11); + let f = async move { + let mut stream = util::wait_for_capacity(stream, 11).await; + sent_settings_rx.await.expect("rx"); + stream.send_data("hello world".into(), true).unwrap(); + let resp = response.await.expect("response"); + assert_eq!(resp.status(), StatusCode::NO_CONTENT); + }; + join(async move { h2.await.expect("h2") }, f).await; + }; - let h2 = client::handshake(io).unwrap() - .and_then(move |(mut client, h2)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://http2.akamai.com/") - .body(()).unwrap(); - - let (response, mut stream) = client.send_request(request, false).unwrap(); - - stream.reserve_capacity(11); - - h2.expect("h2") - .join( - util::wait_for_capacity(stream, 11) - .and_then(|mut stream| { - sent_settings_rx.expect("rx") - .and_then(move |()| { - stream.send_data("hello world".into(), true).unwrap(); - response.expect("response") - }) - .and_then(move |resp| { - assert_eq!(resp.status(), StatusCode::NO_CONTENT); - Ok(client) - }) - }) - ) - }); - - let _ = h2.join(srv) - .wait().unwrap(); + join(srv, h2).await; } -#[test] -fn recv_no_init_window_then_receive_some_init_window() { +#[tokio::test] +async fn recv_no_init_window_then_receive_some_init_window() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); let mut settings = frame::Settings::default(); settings.set_initial_window_size(Some(0)); - let srv = srv.assert_client_handshake_with_settings(settings).unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://http2.akamai.com/") - ) - .idle_ms(100) - .send_frame(frames::settings().initial_window_size(10)) - .recv_frame(frames::settings_ack()) - .recv_frame(frames::data(1, "hello worl")) - .idle_ms(100) - .send_frame(frames::settings().initial_window_size(11)) - .recv_frame(frames::settings_ack()) - .recv_frame(frames::data(1, "d").eos()) - .send_frame( - frames::headers(1) - .response(204) - .eos() - ) - .close(); - - - let h2 = client::handshake(io).unwrap() - .and_then(|(mut client, h2)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://http2.akamai.com/") - .body(()).unwrap(); - - let (response, mut stream) = client.send_request(request, false).unwrap(); + let srv = async move { + let settings = srv.assert_client_handshake_with_settings(settings).await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) + .await; + idle_ms(100).await; + srv.send_frame(frames::settings().initial_window_size(10)) + .await; + srv.recv_frame(frames::settings_ack()).await; + srv.recv_frame(frames::data(1, "hello worl")).await; + idle_ms(100).await; + srv.send_frame(frames::settings().initial_window_size(11)) + .await; + srv.recv_frame(frames::settings_ack()).await; + srv.recv_frame(frames::data(1, "d").eos()).await; + srv.send_frame(frames::headers(1).response(204).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - stream.reserve_capacity(11); + let (response, mut stream) = client.send_request(request, false).unwrap(); - h2.drive(util::wait_for_capacity(stream, 11).map(|s| (response, client, s))) - }) - .and_then(|(h2, (response, client, mut stream))| { - assert_eq!(stream.capacity(), 11); + stream.reserve_capacity(11); - stream.send_data("hello world".into(), true).unwrap(); + let mut stream = h2.drive(util::wait_for_capacity(stream, 11)).await; + assert_eq!(stream.capacity(), 11); - h2.drive(response).map(|c| (c, client)) - }) - .and_then(|((h2, response), client)| { - assert_eq!(response.status(), StatusCode::NO_CONTENT); + stream.send_data("hello world".into(), true).unwrap(); - // Wait for the connection to close - // Hold on to the `client` handle to avoid sending a GO_AWAY frame. - h2.unwrap().map(|c| (c, client)) - }); + let response = h2.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::NO_CONTENT); - let _ = h2.join(srv) - .wait().unwrap(); + // Wait for the connection to close + // Hold on to the `client` handle to avoid sending a GO_AWAY frame. + h2.await.unwrap(); + }; + join(srv, h2).await; } -#[test] -fn settings_lowered_capacity_returns_capacity_to_connection() { - use std::sync::mpsc; - use std::thread; +#[tokio::test] +async fn settings_lowered_capacity_returns_capacity_to_connection() { + use futures::channel::oneshot; + use futures::future::{select, Either}; + use std::time::Instant; + use tokio::timer::Delay; let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - let (tx1, rx1) = mpsc::channel(); - let (tx2, rx2) = mpsc::channel(); + let (io, mut srv) = mock::new(); + let (tx1, rx1) = oneshot::channel(); + let (tx2, rx2) = oneshot::channel(); let window_size = frame::DEFAULT_INITIAL_WINDOW_SIZE as usize; + let (th1_tx, th1_rx) = oneshot::channel(); // Spawn the server on a thread - let th1 = thread::spawn(move || { - let srv = srv.assert_client_handshake().unwrap() - .recv_settings() - .wait().unwrap(); - + tokio::spawn(async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); tx1.send(()).unwrap(); - - let srv = Ok::<_, ()>(srv).into_future() - .recv_frame( - frames::headers(1) - .request("POST", "https://example.com/one") - ) - .recv_frame( - frames::headers(3) - .request("POST", "https://example.com/two") - ) - .idle_ms(200) - // Remove all capacity from streams - .send_frame(frames::settings().initial_window_size(0)) - .recv_frame(frames::settings_ack()) - - // Let stream 3 make progress - .send_frame(frames::window_update(3, 11)) - .recv_frame(frames::data(3, "hello world").eos()) - .wait().unwrap(); - + srv.recv_frame(frames::headers(1).request("POST", "https://example.com/one")) + .await; + srv.recv_frame(frames::headers(3).request("POST", "https://example.com/two")) + .await; + idle_ms(200).await; + // Remove all capacity from streams + srv.send_frame(frames::settings().initial_window_size(0)) + .await; + srv.recv_frame(frames::settings_ack()).await; + + // Let stream 3 make progress + srv.send_frame(frames::window_update(3, 11)).await; + srv.recv_frame(frames::data(3, "hello world").eos()).await; // Wait to get notified // // A timeout is used here to avoid blocking forever if there is a // failure - let _ = rx2.recv_timeout(Duration::from_secs(5)).unwrap(); + let result = select(rx2, Delay::new(Instant::now() + Duration::from_secs(5))).await; + if let Either::Right((_, _)) = result { + panic!("Timed out"); + } - thread::sleep(Duration::from_millis(500)); + idle_ms(500).await; // Reset initial window size - Ok::<_, ()>(srv).into_future() - .send_frame(frames::settings().initial_window_size(window_size as u32)) - .recv_frame(frames::settings_ack()) - - // Get data from first stream - .recv_frame(frames::data(1, "hello world").eos()) - - // Send responses - .send_frame( - frames::headers(1) - .response(204) - .eos() - ) - .send_frame( - frames::headers(3) - .response(204) - .eos() - ) - .close() - .wait().unwrap(); + srv.send_frame(frames::settings().initial_window_size(window_size as u32)) + .await; + srv.recv_frame(frames::settings_ack()).await; + + // Get data from first stream + srv.recv_frame(frames::data(1, "hello world").eos()).await; + + // Send responses + srv.send_frame(frames::headers(1).response(204).eos()).await; + srv.send_frame(frames::headers(3).response(204).eos()).await; + drop(srv); + th1_tx.send(()).unwrap(); }); - let (mut client, h2) = client::handshake(io).unwrap() - .wait().unwrap(); + let (mut client, h2) = client::handshake(io).await.unwrap(); + let (th2_tx, th2_rx) = oneshot::channel(); // Drive client connection - let th2 = thread::spawn(move || { - h2.wait().unwrap(); + tokio::spawn(async move { + h2.await.unwrap(); + th2_tx.send(()).unwrap(); }); // Wait for server handshake to complete. - rx1.recv_timeout(Duration::from_secs(5)).unwrap(); + let result = select(rx1, Delay::new(Instant::now() + Duration::from_secs(5))).await; + if let Either::Right((_, _)) = result { + panic!("Timed out"); + } - let request = Request::post("https://example.com/one") - .body(()).unwrap(); + let request = Request::post("https://example.com/one").body(()).unwrap(); let (resp1, mut stream1) = client.send_request(request, false).unwrap(); - let request = Request::post("https://example.com/two") - .body(()).unwrap(); + let request = Request::post("https://example.com/two").body(()).unwrap(); let (resp2, mut stream2) = client.send_request(request, false).unwrap(); // Reserve capacity for stream one, this will consume all connection level // capacity stream1.reserve_capacity(window_size); - let stream1 = util::wait_for_capacity(stream1, window_size).wait().unwrap(); + let stream1 = util::wait_for_capacity(stream1, window_size).await; // Now, wait for capacity on the other stream stream2.reserve_capacity(11); - let mut stream2 = util::wait_for_capacity(stream2, 11).wait().unwrap(); + let mut stream2 = util::wait_for_capacity(stream2, 11).await; // Send data on stream 2 stream2.send_data("hello world".into(), true).unwrap(); @@ -1111,276 +1040,256 @@ fn settings_lowered_capacity_returns_capacity_to_connection() { tx2.send(()).unwrap(); // Wait for capacity on stream 1 - let mut stream1 = util::wait_for_capacity(stream1, 11).wait().unwrap(); + let mut stream1 = util::wait_for_capacity(stream1, 11).await; stream1.send_data("hello world".into(), true).unwrap(); // Wait for responses.. - let resp = resp1.wait().unwrap(); + let resp = resp1.await.unwrap(); assert_eq!(resp.status(), StatusCode::NO_CONTENT); - let resp = resp2.wait().unwrap(); + let resp = resp2.await.unwrap(); assert_eq!(resp.status(), StatusCode::NO_CONTENT); - th1.join().unwrap(); - th2.join().unwrap(); + th1_rx.await.unwrap(); + th2_rx.await.unwrap(); } -#[test] -fn client_increase_target_window_size() { +#[tokio::test] +async fn client_increase_target_window_size() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame(frames::window_update(0, (2 << 20) - 65_535)) - .close(); - - - let client = client::handshake(io).unwrap() - .and_then(|(_client, mut conn)| { - conn.set_target_window_size(2 << 20); - - conn.unwrap().map(|c| (c, _client)) - }); - - srv.join(client).wait().unwrap(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::window_update(0, (2 << 20) - 65_535)) + .await; + }; + + let client = async move { + let (_client, mut conn) = client::handshake(io).await.unwrap(); + conn.set_target_window_size(2 << 20); + conn.await.unwrap(); + }; + join(srv, client).await; } -#[test] -fn increase_target_window_size_after_using_some() { +#[tokio::test] +async fn increase_target_window_size_after_using_some() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") - .eos() + .eos(), ) - .send_frame(frames::headers(1).response(200)) - .send_frame(frames::data(1, vec![0; 16_384]).eos()) - .recv_frame(frames::window_update(0, (2 << 20) - 65_535)) - .close(); + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.send_frame(frames::data(1, vec![0; 16_384]).eos()).await; + srv.recv_frame(frames::window_update(0, (2 << 20) - 65_535)) + .await; + }; + + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.unwrap(); + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - let client = client::handshake(io).unwrap() - .and_then(|(mut client, conn)| { - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()).unwrap(); + let res = client.send_request(request, true).unwrap().0; - let res = client.send_request(request, true).unwrap().0; + let res = conn.drive(res).await.unwrap(); + conn.set_target_window_size(2 << 20); + // drive an empty future to allow the WINDOW_UPDATE + // to go out while the response capacity is still in use. + conn.drive(yield_once()).await; + let _res = conn.drive(res.into_body().try_concat()).await; + conn.await.expect("client"); + }; - conn.drive(res) - }) - .and_then(|(mut conn, res)| { - conn.set_target_window_size(2 << 20); - // drive an empty future to allow the WINDOW_UPDATE - // to go out while the response capacity is still in use. - let mut yielded = false; - conn.drive(futures::future::poll_fn(move || { - if yielded { - Ok::<_, ()>(().into()) - } else { - yielded = true; - futures::task::current().notify(); - Ok(futures::Async::NotReady) - } - })) - .map(move |(c, _)| (c, res)) - }) - .and_then(|(conn, res)| { - conn.drive(res.into_body().concat2()) - .and_then(|(c, _)| c.expect("client")) - }); - - srv.join(client).wait().unwrap(); + join(srv, client).await; } -#[test] -fn decrease_target_window_size() { +#[tokio::test] +async fn decrease_target_window_size() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") - .eos() + .eos(), ) - .send_frame(frames::headers(1).response(200)) - .send_frame(frames::data(1, vec![0; 16_384])) - .send_frame(frames::data(1, vec![0; 16_384])) - .send_frame(frames::data(1, vec![0; 16_384])) - .send_frame(frames::data(1, vec![0; 16_383]).eos()) - .recv_frame(frames::window_update(0, 16_384)) - .close(); - - let client = client::handshake(io).unwrap() - .and_then(|(mut client, mut conn)| { - conn.set_target_window_size(16_384 * 2); + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.send_frame(frames::data(1, vec![0; 16_384])).await; + srv.send_frame(frames::data(1, vec![0; 16_384])).await; + srv.send_frame(frames::data(1, vec![0; 16_384])).await; + srv.send_frame(frames::data(1, vec![0; 16_383]).eos()).await; + srv.recv_frame(frames::window_update(0, 16_384)).await; + }; + + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.unwrap(); + conn.set_target_window_size(16_384 * 2); - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()).unwrap(); - let (resp, _) = client.send_request(request, true).unwrap(); - conn.drive(resp.expect("response")).map(|c| (c, client)) - }) - .and_then(|((mut conn, res), client)| { - conn.set_target_window_size(16_384); - let mut body = res.into_parts().1; - let mut cap = body.release_capacity().clone(); - - conn.drive(body.concat2().expect("concat")) - .map(|c| (c, client)) - .and_then(move |((conn, bytes), client)| { - assert_eq!(bytes.len(), 65_535); - cap.release_capacity(bytes.len()).unwrap(); - conn.expect("conn").map(|c| (c, client)) - }) - }); - - srv.join(client).wait().unwrap(); + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + let (resp, _) = client.send_request(request, true).unwrap(); + let res = conn.drive(resp).await.expect("response"); + conn.set_target_window_size(16_384); + let mut body = res.into_parts().1; + let mut cap = body.release_capacity().clone(); + + let bytes = conn.drive(body.try_concat()).await.expect("concat"); + assert_eq!(bytes.len(), 65_535); + cap.release_capacity(bytes.len()).unwrap(); + conn.await.expect("conn"); + }; + + join(srv, client).await; } -#[test] -fn server_target_window_size() { +#[tokio::test] +async fn server_target_window_size() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client.assert_server_handshake() - .unwrap() - .recv_settings() - .recv_frame(frames::window_update(0, (2 << 20) - 65_535)) - .close(); - - let srv = server::handshake(io).unwrap() - .and_then(|mut conn| { - conn.set_target_window_size(2 << 20); - conn.into_future().unwrap() - }); - - srv.join(client).wait().unwrap(); + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .recv_frame(frames::window_update(0, (2 << 20) - 65_535)) + .await; + }; + let srv = async move { + let mut conn = server::handshake(io).await.unwrap(); + conn.set_target_window_size(2 << 20); + conn.next().await; + }; + + join(srv, client).await; } -#[test] -fn recv_settings_increase_window_size_after_using_some() { +#[tokio::test] +async fn recv_settings_increase_window_size_after_using_some() { // See https://github.com/hyperium/h2/issues/208 let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); let new_win_size = 16_384 * 4; // 1 bigger than default - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://http2.akamai.com/") - ) - .recv_frame(frames::data(1, vec![0; 16_384])) - .recv_frame(frames::data(1, vec![0; 16_384])) - .recv_frame(frames::data(1, vec![0; 16_384])) - .recv_frame(frames::data(1, vec![0; 16_383])) - .send_frame( - frames::settings() - .initial_window_size(new_win_size as u32) - ) - .recv_frame(frames::settings_ack()) - .send_frame(frames::window_update(0, 1)) - .recv_frame(frames::data(1, vec![0; 1]).eos()) - .send_frame(frames::headers(1).response(200).eos()) - .close(); - - let client = client::handshake(io).unwrap() - .and_then(|(mut client, conn)| { - let request = Request::builder() - .method("POST") - .uri("https://http2.akamai.com/") - .body(()).unwrap(); - let (resp, mut req_body) = client.send_request(request, false).unwrap(); - req_body.send_data(vec![0; new_win_size].into(), true).unwrap(); - conn.drive(resp.expect("response")).map(|c| (c, client)) - }) - .and_then(|((conn, _res), client)| { - conn.expect("client").map(|c| (c, client)) - }); + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) + .await; + srv.recv_frame(frames::data(1, vec![0; 16_384])).await; + srv.recv_frame(frames::data(1, vec![0; 16_384])).await; + srv.recv_frame(frames::data(1, vec![0; 16_384])).await; + srv.recv_frame(frames::data(1, vec![0; 16_383])).await; + srv.send_frame(frames::settings().initial_window_size(new_win_size as u32)) + .await; + srv.recv_frame(frames::settings_ack()).await; + srv.send_frame(frames::window_update(0, 1)).await; + srv.recv_frame(frames::data(1, vec![0; 1]).eos()).await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.unwrap(); + let request = Request::builder() + .method("POST") + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + let (resp, mut req_body) = client.send_request(request, false).unwrap(); + req_body + .send_data(vec![0; new_win_size].into(), true) + .unwrap(); + let _res = conn.drive(resp).await.expect("response"); + conn.await.expect("client"); + }; - srv.join(client).wait().unwrap(); + join(srv, client).await; } -#[test] -fn reserve_capacity_after_peer_closes() { +#[tokio::test] +async fn reserve_capacity_after_peer_closes() { // See https://github.com/hyperium/h2/issues/300 let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://http2.akamai.com/") - ) + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) + .await; // close connection suddenly - .close(); + }; - let client = client::handshake(io).unwrap() - .and_then(|(mut client, conn)| { - let request = Request::builder() - .method("POST") - .uri("https://http2.akamai.com/") - .body(()).unwrap(); - let (resp, req_body) = client.send_request(request, false).unwrap(); - conn.drive(resp.then(move |result| { - assert!(result.is_err()); - Ok::<_, ()>(req_body) - })) + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.unwrap(); + let request = Request::builder() + .method("POST") + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + let (resp, mut req_body) = client.send_request(request, false).unwrap(); + conn.drive(async move { + let result = resp.await; + assert!(result.is_err()); }) - .and_then(|(conn, mut req_body)| { - // As stated in #300, this would panic because the connection - // had already been closed. - req_body.reserve_capacity(1); - conn.expect("client") - }); - - srv.join(client).wait().expect("wait"); + .await; + // As stated in #300, this would panic because the connection + // had already been closed. + req_body.reserve_capacity(1); + conn.await.expect("client"); + }; + + join(srv, client).await; } -#[test] -fn reset_stream_waiting_for_capacity() { +#[tokio::test] +async fn reset_stream_waiting_for_capacity() { // This tests that receiving a reset on a stream that has some available // connection-level window reassigns that window to another stream. let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - let srv = srv - .assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame(frames::headers(1).request("GET", "http://example.com/")) - .recv_frame(frames::headers(3).request("GET", "http://example.com/")) - .recv_frame(frames::headers(5).request("GET", "http://example.com/")) - .recv_frame(frames::data(1, vec![0; 16384])) - .recv_frame(frames::data(1, vec![0; 16384])) - .recv_frame(frames::data(1, vec![0; 16384])) - .recv_frame(frames::data(1, vec![0; 16383]).eos()) - .send_frame(frames::headers(1).response(200)) + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("GET", "http://example.com/")) + .await; + srv.recv_frame(frames::headers(3).request("GET", "http://example.com/")) + .await; + srv.recv_frame(frames::headers(5).request("GET", "http://example.com/")) + .await; + srv.recv_frame(frames::data(1, vec![0; 16384])).await; + srv.recv_frame(frames::data(1, vec![0; 16384])).await; + srv.recv_frame(frames::data(1, vec![0; 16384])).await; + srv.recv_frame(frames::data(1, vec![0; 16383]).eos()).await; + srv.send_frame(frames::headers(1).response(200)).await; // Assign enough connection window for stream 3... - .send_frame(frames::window_update(0, 1)) + srv.send_frame(frames::window_update(0, 1)).await; // but then reset it. - .send_frame(frames::reset(3)) + srv.send_frame(frames::reset(3)).await; // 5 should use that window instead. - .recv_frame(frames::data(5, vec![0; 1]).eos()) - .send_frame(frames::headers(5).response(200)) - .close() - ; - + srv.recv_frame(frames::data(5, vec![0; 1]).eos()).await; + srv.send_frame(frames::headers(5).response(200)).await; + }; fn request() -> Request<()> { Request::builder() .uri("http://example.com/") @@ -1388,87 +1297,78 @@ fn reset_stream_waiting_for_capacity() { .unwrap() } - let client = client::Builder::new() - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(move |(mut client, conn)| { - let (req1, mut send1) = client.send_request( - request(), false).unwrap(); - let (req2, mut send2) = client.send_request( - request(), false).unwrap(); - let (req3, mut send3) = client.send_request( - request(), false).unwrap(); - // Use up the connection window. - send1.send_data(vec![0; 65535].into(), true).unwrap(); - // Queue up for more connection window. - send2.send_data(vec![0; 1].into(), true).unwrap(); - // .. and even more. - send3.send_data(vec![0; 1].into(), true).unwrap(); - conn.expect("h2") - .join(req1.expect("req1")) - .join(req2.then(|r| Ok(r.unwrap_err()))) - .join(req3.expect("req3")) - }); - - - client.join(srv).wait().unwrap(); -} + let client = async move { + let (mut client, conn) = client::Builder::new() + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + let (req1, mut send1) = client.send_request(request(), false).unwrap(); + let (req2, mut send2) = client.send_request(request(), false).unwrap(); + let (req3, mut send3) = client.send_request(request(), false).unwrap(); + // Use up the connection window. + send1.send_data(vec![0; 65535].into(), true).unwrap(); + // Queue up for more connection window. + send2.send_data(vec![0; 1].into(), true).unwrap(); + // .. and even more. + send3.send_data(vec![0; 1].into(), true).unwrap(); + join4( + async move { conn.await.expect("h2") }, + async move { req1.await.expect("req1") }, + async move { req2.await.unwrap_err() }, + async move { req3.await.expect("req3") }, + ) + .await; + }; + join(srv, client).await; +} -#[test] -fn data_padding() { +#[tokio::test] +async fn data_padding() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); let mut body = Vec::new(); body.push(5); body.extend_from_slice(&[b'z'; 100][..]); body.extend_from_slice(&[b'0'; 5][..]); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "http://example.com/") - .eos() + .eos(), ) - .send_frame( + .await; + srv.send_frame( frames::headers(1) .response(200) - .field("content-length", 100) - ) - .send_frame( - frames::data(1, body) - .padded() - .eos() + .field("content-length", 100), ) - .close(); - - let h2 = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let request = Request::builder() - .method(Method::GET) - .uri("http://example.com/") - .body(()) - .unwrap(); + .await; + srv.send_frame(frames::data(1, body).padded().eos()).await; + }; + let h2 = async move { + let (mut client, conn) = client::handshake(io).await.expect("handshake"); + let request = Request::builder() + .method(Method::GET) + .uri("http://example.com/") + .body(()) + .unwrap(); - // first request is allowed - let (response, _) = client.send_request(request, true).unwrap(); - let fut = response - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - let body = resp.into_body(); - body.concat2() - }) - .map(|bytes| { - assert_eq!(bytes.len(), 100); - }); - conn - .expect("client") - .join(fut.expect("response")) - }); - - h2.join(srv).wait().expect("wait"); + // first request is allowed + let (response, _) = client.send_request(request, true).unwrap(); + let fut = async move { + let resp = response.await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body = resp.into_body(); + let bytes = body.try_concat().await.unwrap(); + assert_eq!(bytes.len(), 100); + }; + join(async move { conn.await.expect("client") }, fut).await; + }; + + join(srv, h2).await; } diff --git a/tests/h2-tests/tests/hammer.rs b/tests/h2-tests/tests/hammer.rs index b672c9aec..d3a605209 100644 --- a/tests/h2-tests/tests/hammer.rs +++ b/tests/h2-tests/tests/hammer.rs @@ -1,13 +1,26 @@ -use h2_support::prelude::*; -use futures::{Async, Poll}; +#![feature(async_await)] +use futures::{ready, FutureExt, StreamExt, TryFutureExt}; +use h2_support::prelude::*; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use std::io; +use std::{ + net::SocketAddr, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + thread, +}; use tokio::net::{TcpListener, TcpStream}; -use std::{net::SocketAddr, thread, sync::{atomic::{AtomicUsize, Ordering}, Arc}}; struct Server { addr: SocketAddr, reqs: Arc, - join: Option>, + _join: Option>, } impl Server { @@ -23,32 +36,27 @@ impl Server { let reqs = Arc::new(AtomicUsize::new(0)); let reqs2 = reqs.clone(); let join = thread::spawn(move || { - let server = listener.incoming().for_each(move |socket| { - let reqs = reqs2.clone(); - let mk_data = mk_data.clone(); - let connection = server::handshake(socket) - .and_then(move |conn| { - conn.for_each(move |(_, mut respond)| { - reqs.fetch_add(1, Ordering::Release); - let response = Response::builder().status(StatusCode::OK).body(()).unwrap(); - let mut send = respond.send_response(response, false)?; - send.send_data(mk_data(), true).map(|_|()) - }) - }) - .map_err(|e| eprintln!("serve conn error: {:?}", e)); - - tokio::spawn(Box::new(connection)); - Ok(()) - }) - .map_err(|e| eprintln!("serve error: {:?}", e)); + let server = async move { + let mut incoming = listener.incoming(); + while let Some(socket) = incoming.next().await { + let reqs = reqs2.clone(); + let mk_data = mk_data.clone(); + tokio::spawn(async move { + if let Err(e) = handle_request(socket, reqs, mk_data).await { + eprintln!("serve conn error: {:?}", e) + } + }); + } + }; - tokio::run(server); + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(server); }); Self { addr, - join: Some(join), - reqs + _join: Some(join), + reqs, } } @@ -61,6 +69,25 @@ impl Server { } } +async fn handle_request( + socket: io::Result, + reqs: Arc, + mk_data: Arc, +) -> Result<(), Box> +where + F: Fn() -> Bytes, + F: Send + Sync + 'static, +{ + let mut conn = server::handshake(socket?).await?; + while let Some(result) = conn.next().await { + let (_, mut respond) = result?; + reqs.fetch_add(1, Ordering::Release); + let response = Response::builder().status(StatusCode::OK).body(()).unwrap(); + let mut send = respond.send_response(response, false)?; + send.send_data(mk_data(), true)?; + } + Ok(()) +} struct Process { body: RecvStream, @@ -68,30 +95,25 @@ struct Process { } impl Future for Process { - type Item = (); - type Error = h2::Error; + type Output = Result<(), h2::Error>; - fn poll(&mut self) -> Poll<(), h2::Error> { + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { if self.trailers { - return match self.body.poll_trailers()? { - Async::NotReady => Ok(Async::NotReady), - Async::Ready(_) => Ok(().into()), - }; + ready!(self.body.poll_trailers(cx)); + return Poll::Ready(Ok(())); } else { - match self.body.poll()? { - Async::NotReady => return Ok(Async::NotReady), - Async::Ready(None) => { + match ready!(Pin::new(&mut self.body).poll_next(cx)) { + None => { self.trailers = true; - }, - _ => {}, + } + _ => {} } } } } } - #[test] fn hammer_client_concurrency() { // This reproduces issue #326. @@ -106,10 +128,12 @@ fn hammer_client_concurrency() { print!("sending {}", i); let rsps = rsps.clone(); let tcp = TcpStream::connect(&addr); - let tcp = tcp.then(|res| { - let tcp = res.unwrap(); - client::handshake(tcp) - }).then(move |res| { + let tcp = tcp + .then(|res| { + let tcp = res.unwrap(); + client::handshake(tcp) + }) + .then(move |res| { let rsps = rsps; let (mut client, h2) = res.unwrap(); let request = Request::builder() @@ -120,7 +144,9 @@ fn hammer_client_concurrency() { let (response, mut stream) = client.send_request(request, false).unwrap(); stream.send_trailers(HeaderMap::new()).unwrap(); - tokio::spawn(h2.map_err(|e| panic!("client conn error: {:?}", e))); + tokio::spawn(async move { + h2.await.unwrap(); + }); response .and_then(|response| { @@ -139,7 +165,8 @@ fn hammer_client_concurrency() { }) }); - tokio::run(tcp); + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(tcp); println!("...done"); } diff --git a/tests/h2-tests/tests/ping_pong.rs b/tests/h2-tests/tests/ping_pong.rs index f3e194fc5..3cdec5919 100644 --- a/tests/h2-tests/tests/ping_pong.rs +++ b/tests/h2-tests/tests/ping_pong.rs @@ -1,205 +1,188 @@ -use h2_support::prelude::*; +#![feature(async_await)] + +use futures::channel::oneshot; +use futures::future::join; +use futures::{StreamExt, TryStreamExt}; use h2_support::assert_ping; +use h2_support::prelude::*; -#[test] -fn recv_single_ping() { +#[tokio::test] +async fn recv_single_ping() { let _ = env_logger::try_init(); - let (m, mock) = mock::new(); + let (m, mut mock) = mock::new(); // Create the handshake - let h2 = client::handshake(m) - .unwrap() - .and_then(|(client, conn)| { - conn.unwrap() - .map(|c| (client, c)) - }); - - let mock = mock.assert_client_handshake() - .unwrap() - .and_then(|(_, mut mock)| { - let frame = frame::Ping::new(Default::default()); - mock.send(frame.into()).unwrap(); - - mock.into_future().unwrap() - }) - .and_then(|(frame, _)| { - let pong = assert_ping!(frame.unwrap()); - - // Payload is correct - assert_eq!(*pong.payload(), <[u8; 8]>::default()); - - // Is ACK - assert!(pong.is_ack()); - - Ok(()) - }); - - let _ = h2.join(mock).wait().unwrap(); + let h2 = async move { + let (client, conn) = client::handshake(m).await.unwrap(); + let c = conn.await.unwrap(); + (client, c) + }; + + let mock = async move { + let _ = mock.assert_client_handshake().await; + let frame = frame::Ping::new(Default::default()); + mock.send(frame.into()).await.unwrap(); + let frame = mock.next().await.unwrap(); + + let pong = assert_ping!(frame.unwrap()); + + // Payload is correct + assert_eq!(*pong.payload(), <[u8; 8]>::default()); + + // Is ACK + assert!(pong.is_ack()); + }; + + join(mock, h2).await; } -#[test] -fn recv_multiple_pings() { +#[tokio::test] +async fn recv_multiple_pings() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client.assert_server_handshake() - .expect("client handshake") - .recv_settings() - .send_frame(frames::ping([1; 8])) - .send_frame(frames::ping([2; 8])) - .recv_frame(frames::ping([1; 8]).pong()) - .recv_frame(frames::ping([2; 8]).pong()) - .close(); - - let srv = server::handshake(io) - .expect("handshake") - .and_then(|srv| { - // future of first request, which never comes - srv.into_future().unwrap() - }); - - srv.join(client).wait().expect("wait"); + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client.send_frame(frames::ping([1; 8])).await; + client.send_frame(frames::ping([2; 8])).await; + client.recv_frame(frames::ping([1; 8]).pong()).await; + client.recv_frame(frames::ping([2; 8]).pong()).await; + }; + + let srv = async move { + let mut s = server::handshake(io).await.expect("handshake"); + assert!(s.next().await.is_none()); + }; + + join(client, srv).await; } -#[test] -fn pong_has_highest_priority() { +#[tokio::test] +async fn pong_has_highest_priority() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); + let (io, mut client) = mock::new(); let data = Bytes::from(vec![0; 16_384]); - - let client = client.assert_server_handshake() - .expect("client handshake") - .recv_settings() - .send_frame( - frames::headers(1) - .request("POST", "https://http2.akamai.com/") - ) - .send_frame(frames::data(1, data.clone()).eos()) - .send_frame(frames::ping([1; 8])) - .recv_frame(frames::ping([1; 8]).pong()) - .recv_frame(frames::headers(1).response(200).eos()) - .close(); - - let srv = server::handshake(io) - .expect("handshake") - .and_then(|srv| { - // future of first request - srv.into_future().unwrap() - }).and_then(move |(reqstream, srv)| { - let (req, mut stream) = reqstream.expect("request"); - assert_eq!(req.method(), "POST"); - let body = req.into_parts().1; - - body.concat2() - .expect("body") - .and_then(move |body| { - assert_eq!(body.len(), data.len()); - let res = Response::builder() - .status(200) - .body(()) - .unwrap(); - stream.send_response(res, true).expect("response"); - srv.into_future().unwrap() - }) - }); - - srv.join(client).wait().expect("wait"); + let data_clone = data.clone(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) + .await; + client.send_frame(frames::data(1, data_clone).eos()).await; + client.send_frame(frames::ping([1; 8])).await; + client.recv_frame(frames::ping([1; 8]).pong()).await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; + }; + + let srv = async move { + let mut s = server::handshake(io).await.expect("handshake"); + let (req, mut stream) = s.next().await.unwrap().unwrap(); + assert_eq!(req.method(), "POST"); + let body = req.into_parts().1; + + let body = body.try_concat().await.expect("body"); + assert_eq!(body.len(), data.len()); + let res = Response::builder().status(200).body(()).unwrap(); + stream.send_response(res, true).expect("response"); + assert!(s.next().await.is_none()); + }; + + join(client, srv).await; } -#[test] -fn user_ping_pong() { +#[tokio::test] +async fn user_ping_pong() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - let srv = srv.assert_client_handshake() - .expect("srv handshake") - .recv_settings() - .recv_frame(frames::ping(frame::Ping::USER)) - .send_frame(frames::ping(frame::Ping::USER).pong()) - .recv_frame(frames::go_away(0)) - .recv_eof(); - - let client = client::handshake(io) - .expect("client handshake") - .and_then(|(client, conn)| { - // yield once so we can ack server settings - conn - .drive(util::yield_once()) - .map(move |(conn, ())| (client, conn)) - }) - .and_then(|(client, mut conn)| { - // `ping_pong()` method conflict with mock future ext trait. - let mut ping_pong = client::Connection::ping_pong(&mut conn) - .expect("taking ping_pong"); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::ping(frame::Ping::USER)).await; + srv.send_frame(frames::ping(frame::Ping::USER).pong()).await; + srv.recv_frame(frames::go_away(0)).await; + srv.recv_eof().await; + }; + + let client = async move { + let (client, mut conn) = client::handshake(io).await.expect("client handshake"); + // yield once so we can ack server settings + conn.drive(util::yield_once()).await; + // `ping_pong()` method conflict with mock future ext trait. + let mut ping_pong = client::Connection::ping_pong(&mut conn).expect("taking ping_pong"); + ping_pong.send_ping(Ping::opaque()).expect("send ping"); + + // multiple pings results in a user error... + assert_eq!( ping_pong .send_ping(Ping::opaque()) - .expect("send ping"); - - // multiple pings results in a user error... - assert_eq!( - ping_pong.send_ping(Ping::opaque()).expect_err("ping 2").to_string(), - "user error: send_ping before received previous pong", - "send_ping while ping pending is a user error", - ); - - conn - .drive(futures::future::poll_fn(move || { - ping_pong.poll_pong() - })) - .and_then(move |(conn, _pong)| { - drop(client); - conn.expect("client") - }) - }); - - client.join(srv).wait().expect("wait"); + .expect_err("ping 2") + .to_string(), + "user error: send_ping before received previous pong", + "send_ping while ping pending is a user error", + ); + + conn.drive(futures::future::poll_fn(move |cx| ping_pong.poll_pong(cx))) + .await + .unwrap(); + drop(client); + conn.await.expect("client"); + }; + + join(srv, client).await; } -#[test] -fn user_notifies_when_connection_closes() { +#[tokio::test] +async fn user_notifies_when_connection_closes() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - let srv = srv.assert_client_handshake() - .expect("srv handshake") - .recv_settings(); - - let client = client::handshake(io) - .expect("client handshake") - .and_then(|(client, conn)| { - // yield once so we can ack server settings - conn - .drive(util::yield_once()) - .map(move |(conn, ())| (client, conn)) - }) - .map(|(_client, conn)| conn); - - let (mut client, srv) = client.join(srv).wait().expect("wait"); + let (io, mut srv) = mock::new(); + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv + }; + + let client = async move { + let (_client, mut conn) = client::handshake(io).await.expect("client handshake"); + // yield once so we can ack server settings + conn.drive(util::yield_once()).await; + conn + }; + + let (srv, mut client) = join(srv, client).await; // `ping_pong()` method conflict with mock future ext trait. - let mut ping_pong = client::Connection::ping_pong(&mut client) - .expect("taking ping_pong"); + let mut ping_pong = client::Connection::ping_pong(&mut client).expect("taking ping_pong"); // Spawn a thread so we can park a task waiting on `poll_pong`, and then // drop the client and be sure the parked task is notified... - let t = thread::spawn(move || { - poll_fn(|| { ping_pong.poll_pong() }) - .wait() + let (tx, rx) = oneshot::channel(); + tokio::spawn(async move { + poll_fn(|cx| ping_pong.poll_pong(cx)) + .await .expect_err("poll_pong should error"); - ping_pong + tx.send(ping_pong).unwrap(); }); // Sleep to let the ping thread park its task... - thread::sleep(Duration::from_millis(50)); + idle_ms(50).await; drop(client); drop(srv); - let mut ping_pong = t.join().expect("ping pong thread join"); + let mut ping_pong = rx.await.expect("ping pong spawn join"); // Now that the connection is closed, also test `send_ping` errors... assert_eq!( - ping_pong.send_ping(Ping::opaque()).expect_err("send_ping").to_string(), + ping_pong + .send_ping(Ping::opaque()) + .expect_err("send_ping") + .to_string(), "broken pipe", ); } diff --git a/tests/h2-tests/tests/prioritization.rs b/tests/h2-tests/tests/prioritization.rs index 7cd197afd..dcfdf3294 100644 --- a/tests/h2-tests/tests/prioritization.rs +++ b/tests/h2-tests/tests/prioritization.rs @@ -1,8 +1,13 @@ -use h2_support::{DEFAULT_WINDOW_SIZE}; +#![feature(async_await)] + +use futures::future::join; +use futures::{FutureExt, StreamExt}; use h2_support::prelude::*; +use h2_support::DEFAULT_WINDOW_SIZE; +use std::task::Context; -#[test] -fn single_stream_send_large_body() { +#[tokio::test] +async fn single_stream_send_large_body() { let _ = env_logger::try_init(); let payload = [0; 1024]; @@ -12,8 +17,8 @@ fn single_stream_send_large_body() { .write(frames::SETTINGS_ACK) .write(&[ // POST / - 0, 0, 16, 1, 4, 0, 0, 0, 1, 131, 135, 65, 139, 157, 41, - 172, 75, 143, 168, 233, 25, 151, 33, 233, 132, + 0, 0, 16, 1, 4, 0, 0, 0, 1, 131, 135, 65, 139, 157, 41, 172, 75, 143, 168, 233, 25, 151, + 33, 233, 132, ]) .write(&[ // DATA @@ -24,16 +29,15 @@ fn single_stream_send_large_body() { .read(&[0, 0, 1, 1, 5, 0, 0, 0, 1, 0x89]) .build(); - let notify = MockNotify::new(); - let (mut client, mut h2) = client::handshake(mock).wait().unwrap(); + let (mut client, mut h2) = client::handshake(mock).await.unwrap(); + let waker = futures::task::noop_waker(); + let mut cx = Context::from_waker(&waker); // Poll h2 once to get notifications loop { // Run the connection until all work is done, this handles processing // the handshake. - notify.with(|| h2.poll()).unwrap(); - - if !notify.is_notified() { + if !h2.poll_unpin(&mut cx).is_ready() { break; } } @@ -55,80 +59,78 @@ fn single_stream_send_large_body() { // Send the data stream.send_data(payload[..].into(), true).unwrap(); - assert!(notify.is_notified()); - // Get the response - let resp = h2.run(response).unwrap(); + let resp = h2.run(response).await.unwrap(); assert_eq!(resp.status(), StatusCode::NO_CONTENT); - h2.wait().unwrap(); + h2.await.unwrap(); } -#[test] -fn multiple_streams_with_payload_greater_than_default_window() { +#[tokio::test] +async fn multiple_streams_with_payload_greater_than_default_window() { let _ = env_logger::try_init(); - let payload = vec![0; 16384*5-1]; - - let (io, srv) = mock::new(); - - let srv = srv.assert_client_handshake().unwrap() - .recv_settings() - .recv_frame( - frames::headers(1).request("POST", "https://http2.akamai.com/") - ) - .recv_frame( - frames::headers(3).request("POST", "https://http2.akamai.com/") - ) - .recv_frame( - frames::headers(5).request("POST", "https://http2.akamai.com/") - ) - .recv_frame(frames::data(1, &payload[0..16_384])) - .recv_frame(frames::data(1, &payload[16_384..(16_384*2)])) - .recv_frame(frames::data(1, &payload[(16_384*2)..(16_384*3)])) - .recv_frame(frames::data(1, &payload[(16_384*3)..(16_384*4-1)])) - .send_frame(frames::settings()) - .recv_frame(frames::settings_ack()) - .send_frame(frames::headers(1).response(200).eos()) - .send_frame(frames::headers(3).response(200).eos()) - .send_frame(frames::headers(5).response(200).eos()) - .close(); - - let client = client::handshake(io).unwrap() - .and_then(|(mut client, conn)| { - let request1 = Request::post("https://http2.akamai.com/").body(()).unwrap(); - let request2 = Request::post("https://http2.akamai.com/").body(()).unwrap(); - let request3 = Request::post("https://http2.akamai.com/").body(()).unwrap(); - let (response1, mut stream1) = client.send_request(request1, false).unwrap(); - let (_response2, mut stream2) = client.send_request(request2, false).unwrap(); - let (_response3, mut stream3) = client.send_request(request3, false).unwrap(); - - // The capacity should be immediately - // allocated to default window size (smaller than payload) - stream1.reserve_capacity(payload.len()); - assert_eq!(stream1.capacity(), DEFAULT_WINDOW_SIZE); - - stream2.reserve_capacity(payload.len()); - assert_eq!(stream2.capacity(), 0); - - stream3.reserve_capacity(payload.len()); - assert_eq!(stream3.capacity(), 0); - - stream1.send_data(payload[..].into(), true).unwrap(); - - // hold onto streams so they don't close - // stream1 doesn't close because response1 is used - conn.drive(response1.expect("response")).map(|c| (c, client, stream2, stream3)) - }) - .and_then(|((conn, _res), client, stream2, stream3)| { - conn.expect("client").map(|c| (c, client, stream2, stream3)) - }); - - srv.join(client).wait().unwrap(); + let payload = vec![0; 16384 * 5 - 1]; + let payload_clone = payload.clone(); + + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://http2.akamai.com/")) + .await; + srv.recv_frame(frames::headers(3).request("POST", "https://http2.akamai.com/")) + .await; + srv.recv_frame(frames::headers(5).request("POST", "https://http2.akamai.com/")) + .await; + srv.recv_frame(frames::data(1, &payload[0..16_384])).await; + srv.recv_frame(frames::data(1, &payload[16_384..(16_384 * 2)])) + .await; + srv.recv_frame(frames::data(1, &payload[(16_384 * 2)..(16_384 * 3)])) + .await; + srv.recv_frame(frames::data(1, &payload[(16_384 * 3)..(16_384 * 4 - 1)])) + .await; + srv.send_frame(frames::settings()).await; + srv.recv_frame(frames::settings_ack()).await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.send_frame(frames::headers(3).response(200).eos()).await; + srv.send_frame(frames::headers(5).response(200).eos()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.unwrap(); + let request1 = Request::post("https://http2.akamai.com/").body(()).unwrap(); + let request2 = Request::post("https://http2.akamai.com/").body(()).unwrap(); + let request3 = Request::post("https://http2.akamai.com/").body(()).unwrap(); + let (response1, mut stream1) = client.send_request(request1, false).unwrap(); + let (_response2, mut stream2) = client.send_request(request2, false).unwrap(); + let (_response3, mut stream3) = client.send_request(request3, false).unwrap(); + + // The capacity should be immediately + // allocated to default window size (smaller than payload) + stream1.reserve_capacity(payload_clone.len()); + assert_eq!(stream1.capacity(), DEFAULT_WINDOW_SIZE); + + stream2.reserve_capacity(payload_clone.len()); + assert_eq!(stream2.capacity(), 0); + + stream3.reserve_capacity(payload_clone.len()); + assert_eq!(stream3.capacity(), 0); + + stream1.send_data(payload_clone[..].into(), true).unwrap(); + + // hold onto streams so they don't close + // stream1 doesn't close because response1 is used + let _res = conn.drive(response1).await.expect("response"); + conn.await.expect("client"); + }; + + join(srv, client).await; } -#[test] -fn single_stream_send_extra_large_body_multi_frames_one_buffer() { +#[tokio::test] +async fn single_stream_send_extra_large_body_multi_frames_one_buffer() { let _ = env_logger::try_init(); let payload = vec![0; 32_768]; @@ -138,8 +140,8 @@ fn single_stream_send_extra_large_body_multi_frames_one_buffer() { .write(frames::SETTINGS_ACK) .write(&[ // POST / - 0, 0, 16, 1, 4, 0, 0, 0, 1, 131, 135, 65, 139, 157, 41, - 172, 75, 143, 168, 233, 25, 151, 33, 233, 132, + 0, 0, 16, 1, 4, 0, 0, 0, 1, 131, 135, 65, 139, 157, 41, 172, 75, 143, 168, 233, 25, 151, + 33, 233, 132, ]) .write(&[ // DATA @@ -155,16 +157,15 @@ fn single_stream_send_extra_large_body_multi_frames_one_buffer() { .read(&[0, 0, 1, 1, 5, 0, 0, 0, 1, 0x89]) .build(); - let notify = MockNotify::new(); - let (mut client, mut h2) = client::handshake(mock).wait().unwrap(); + let (mut client, mut h2) = client::handshake(mock).await.unwrap(); + let waker = futures::task::noop_waker(); + let mut cx = Context::from_waker(&waker); // Poll h2 once to get notifications loop { // Run the connection until all work is done, this handles processing // the handshake. - notify.with(|| h2.poll()).unwrap(); - - if !notify.is_notified() { + if !h2.poll_unpin(&mut cx).is_ready() { break; } } @@ -185,28 +186,26 @@ fn single_stream_send_extra_large_body_multi_frames_one_buffer() { // Send the data stream.send_data(payload.into(), true).unwrap(); - assert!(notify.is_notified()); - // Get the response - let resp = h2.run(response).unwrap(); + let resp = h2.run(response).await.unwrap(); assert_eq!(resp.status(), StatusCode::NO_CONTENT); - h2.wait().unwrap(); + h2.await.unwrap(); } -#[test] -fn single_stream_send_body_greater_than_default_window() { +#[tokio::test] +async fn single_stream_send_body_greater_than_default_window() { let _ = env_logger::try_init(); - let payload = vec![0; 16384*5-1]; + let payload = vec![0; 16384 * 5 - 1]; let mock = mock_io::Builder::new() .handshake() .write(frames::SETTINGS_ACK) .write(&[ // POST / - 0, 0, 16, 1, 4, 0, 0, 0, 1, 131, 135, 65, 139, 157, 41, - 172, 75, 143, 168, 233, 25, 151, 33, 233, 132, + 0, 0, 16, 1, 4, 0, 0, 0, 1, 131, 135, 65, 139, 157, 41, 172, 75, 143, 168, 233, 25, 151, + 33, 233, 132, ]) .write(&[ // DATA @@ -217,41 +216,38 @@ fn single_stream_send_body_greater_than_default_window() { // DATA 0, 64, 0, 0, 0, 0, 0, 0, 1, ]) - .write(&payload[16_384..(16_384*2)]) + .write(&payload[16_384..(16_384 * 2)]) .write(&[ // DATA 0, 64, 0, 0, 0, 0, 0, 0, 1, ]) - .write(&payload[(16_384*2)..(16_384*3)]) + .write(&payload[(16_384 * 2)..(16_384 * 3)]) .write(&[ // DATA 0, 63, 255, 0, 0, 0, 0, 0, 1, ]) - .write(&payload[(16_384*3)..(16_384*4-1)]) - + .write(&payload[(16_384 * 3)..(16_384 * 4 - 1)]) // Read window update .read(&[0, 0, 4, 8, 0, 0, 0, 0, 0, 0, 0, 64, 0]) .read(&[0, 0, 4, 8, 0, 0, 0, 0, 1, 0, 0, 64, 0]) - .write(&[ // DATA 0, 64, 0, 0, 1, 0, 0, 0, 1, ]) - .write(&payload[(16_384*4-1)..(16_384*5-1)]) + .write(&payload[(16_384 * 4 - 1)..(16_384 * 5 - 1)]) // Read response .read(&[0, 0, 1, 1, 5, 0, 0, 0, 1, 0x89]) .build(); - let notify = MockNotify::new(); - let (mut client, mut h2) = client::handshake(mock).wait().unwrap(); + let (mut client, mut h2) = client::handshake(mock).await.unwrap(); + let waker = futures::task::noop_waker(); + let mut cx = Context::from_waker(&waker); // Poll h2 once to get notifications loop { // Run the connection until all work is done, this handles processing // the handshake. - notify.with(|| h2.poll()).unwrap(); - - if !notify.is_notified() { + if !h2.poll_unpin(&mut cx).is_ready() { break; } } @@ -268,9 +264,7 @@ fn single_stream_send_body_greater_than_default_window() { loop { // Run the connection until all work is done, this handles processing // the handshake. - notify.with(|| h2.poll()).unwrap(); - - if !notify.is_notified() { + if !h2.poll_unpin(&mut cx).is_ready() { break; } } @@ -278,17 +272,15 @@ fn single_stream_send_body_greater_than_default_window() { // Send the data stream.send_data(payload.into(), true).unwrap(); - assert!(notify.is_notified()); - // Get the response - let resp = h2.run(response).unwrap(); + let resp = h2.run(response).await.unwrap(); assert_eq!(resp.status(), StatusCode::NO_CONTENT); - h2.wait().unwrap(); + h2.await.unwrap(); } -#[test] -fn single_stream_send_extra_large_body_multi_frames_multi_buffer() { +#[tokio::test] +async fn single_stream_send_extra_large_body_multi_frames_multi_buffer() { let _ = env_logger::try_init(); let payload = vec![0; 32_768]; @@ -300,24 +292,20 @@ fn single_stream_send_extra_large_body_multi_frames_multi_buffer() { .read(frames::SETTINGS) // Add wait to force the data writes to chill .wait(Duration::from_millis(10)) - // Rest - .write(&[ // POST / - 0, 0, 16, 1, 4, 0, 0, 0, 1, 131, 135, 65, 139, 157, 41, - 172, 75, 143, 168, 233, 25, 151, 33, 233, 132, + 0, 0, 16, 1, 4, 0, 0, 0, 1, 131, 135, 65, 139, 157, 41, 172, 75, 143, 168, 233, 25, 151, + 33, 233, 132, ]) .write(&[ // DATA 0, 64, 0, 0, 0, 0, 0, 0, 1, ]) .write(&payload[0..16_384]) - .write(frames::SETTINGS_ACK) .read(frames::SETTINGS_ACK) .wait(Duration::from_millis(10)) - .write(&[ // DATA 0, 64, 0, 0, 1, 0, 0, 0, 1, @@ -327,7 +315,7 @@ fn single_stream_send_extra_large_body_multi_frames_multi_buffer() { .read(&[0, 0, 1, 1, 5, 0, 0, 0, 1, 0x89]) .build(); - let (mut client, mut h2) = client::handshake(mock).wait().unwrap(); + let (mut client, mut h2) = client::handshake(mock).await.unwrap(); let request = Request::builder() .method(Method::POST) @@ -346,92 +334,79 @@ fn single_stream_send_extra_large_body_multi_frames_multi_buffer() { stream.send_data(payload.into(), true).unwrap(); // Get the response - let resp = h2.run(response).unwrap(); + let resp = h2.run(response).await.unwrap(); assert_eq!(resp.status(), StatusCode::NO_CONTENT); - h2.wait().unwrap(); + h2.await.unwrap(); } -#[test] -fn send_data_receive_window_update() { +#[tokio::test] +async fn send_data_receive_window_update() { let _ = env_logger::try_init(); - let (m, mock) = mock::new(); + let (m, mut mock) = mock::new(); - let h2 = client::handshake(m) - .unwrap() - .and_then(|(mut client, h2)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); + let h2 = async move { + let (mut client, mut h2) = client::handshake(m).await.unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - // Send request - let (response, mut stream) = client.send_request(request, false).unwrap(); + // Send request + let (_response, mut stream) = client.send_request(request, false).unwrap(); - // Send data frame - stream.send_data("hello".into(), false).unwrap(); + // Send data frame + stream.send_data("hello".into(), false).unwrap(); - stream.reserve_capacity(frame::DEFAULT_INITIAL_WINDOW_SIZE as usize); + stream.reserve_capacity(frame::DEFAULT_INITIAL_WINDOW_SIZE as usize); - // Wait for capacity - h2.drive(util::wait_for_capacity( + // Wait for capacity + let mut stream = h2 + .drive(util::wait_for_capacity( stream, frame::DEFAULT_INITIAL_WINDOW_SIZE as usize, - ).map(|s| (response, s))) - }) - .and_then(|(h2, (_r, mut stream))| { - let payload = vec![0; frame::DEFAULT_INITIAL_WINDOW_SIZE as usize]; - stream.send_data(payload.into(), true).unwrap(); - - // keep `stream` from being dropped in order to prevent - // it from sending an RST_STREAM frame. - std::mem::forget(stream); - h2.unwrap() - }); - - let mock = mock.assert_client_handshake().unwrap() - .and_then(|(_, mock)| mock.into_future().unwrap()) - .and_then(|(frame, mock)| { - let request = assert_headers!(frame.unwrap()); - assert!(!request.is_end_stream()); - mock.into_future().unwrap() - }) - .and_then(|(frame, mut mock)| { - let data = assert_data!(frame.unwrap()); - - // Update the windows - let len = data.payload().len(); - let f = frame::WindowUpdate::new(StreamId::zero(), len as u32); - mock.send(f.into()).unwrap(); - - let f = frame::WindowUpdate::new(data.stream_id(), len as u32); - mock.send(f.into()).unwrap(); - - mock.into_future().unwrap() - }) - // TODO: Dedup the following lines - .and_then(|(frame, mock)| { - let data = assert_data!(frame.unwrap()); - assert_eq!(data.payload().len(), frame::DEFAULT_MAX_FRAME_SIZE as usize); - mock.into_future().unwrap() - }) - .and_then(|(frame, mock)| { - let data = assert_data!(frame.unwrap()); - assert_eq!(data.payload().len(), frame::DEFAULT_MAX_FRAME_SIZE as usize); - mock.into_future().unwrap() - }) - .and_then(|(frame, mock)| { + )) + .await; + let payload = vec![0; frame::DEFAULT_INITIAL_WINDOW_SIZE as usize]; + stream.send_data(payload.into(), true).unwrap(); + + // keep `stream` from being dropped in order to prevent + // it from sending an RST_STREAM frame. + std::mem::forget(stream); + h2.await.unwrap(); + }; + + let mock = async move { + let _ = mock.assert_client_handshake().await; + + let frame = mock.next().await.unwrap(); + let request = assert_headers!(frame.unwrap()); + assert!(!request.is_end_stream()); + let frame = mock.next().await.unwrap(); + let data = assert_data!(frame.unwrap()); + + // Update the windows + let len = data.payload().len(); + let f = frame::WindowUpdate::new(StreamId::zero(), len as u32); + mock.send(f.into()).await.unwrap(); + + let f = frame::WindowUpdate::new(data.stream_id(), len as u32); + mock.send(f.into()).await.unwrap(); + + for _ in 0..3usize { + let frame = mock.next().await.unwrap(); let data = assert_data!(frame.unwrap()); assert_eq!(data.payload().len(), frame::DEFAULT_MAX_FRAME_SIZE as usize); - mock.into_future().unwrap() - }) - .and_then(|(frame, _)| { - let data = assert_data!(frame.unwrap()); - assert_eq!(data.payload().len(), (frame::DEFAULT_MAX_FRAME_SIZE-1) as usize); - Ok(()) - }); - - let _ = h2.join(mock).wait().unwrap(); + } + let frame = mock.next().await.unwrap(); + let data = assert_data!(frame.unwrap()); + assert_eq!( + data.payload().len(), + (frame::DEFAULT_MAX_FRAME_SIZE - 1) as usize + ); + }; + + join(mock, h2).await; } diff --git a/tests/h2-tests/tests/push_promise.rs b/tests/h2-tests/tests/push_promise.rs index 7c1eef443..f37ee3178 100644 --- a/tests/h2-tests/tests/push_promise.rs +++ b/tests/h2-tests/tests/push_promise.rs @@ -1,310 +1,346 @@ +#![feature(async_await)] +use futures::future::join; +use futures::{StreamExt, TryStreamExt}; use h2_support::prelude::*; -#[test] -fn recv_push_works() { +#[tokio::test] +async fn recv_push_works() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - let mock = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let (io, mut srv) = mock::new(); + let mock = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .send_frame(frames::headers(1).response(404)) - .send_frame(frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css")) - .send_frame(frames::data(1, "").eos()) - .send_frame(frames::headers(2).response(200)) - .send_frame(frames::data(2, "promised_data").eos()); - - let h2 = client::handshake(io).unwrap().and_then(|(mut client, h2)| { + .await; + srv.send_frame(frames::headers(1).response(404)).await; + srv.send_frame( + frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css"), + ) + .await; + srv.send_frame(frames::data(1, "").eos()).await; + srv.send_frame(frames::headers(2).response(200)).await; + srv.send_frame(frames::data(2, "promised_data").eos()).await; + }; + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); let request = Request::builder() .method(Method::GET) .uri("https://http2.akamai.com/") .body(()) .unwrap(); - let (mut resp, _) = client - .send_request(request, true) - .unwrap(); + let (mut resp, _) = client.send_request(request, true).unwrap(); let pushed = resp.push_promises(); - let check_resp_status = resp.unwrap().map(|resp| { - assert_eq!(resp.status(), StatusCode::NOT_FOUND) - }); - let check_pushed_request = pushed.and_then(|headers| { - let (request, response) = headers.into_parts(); - assert_eq!(request.into_parts().0.method, Method::GET); - response - }); - let check_pushed_response = check_pushed_request.and_then( - |resp| { - assert_eq!(resp.status(), StatusCode::OK); - resp.into_body().concat2().map(|b| assert_eq!(b, "promised_data")) - } - ).collect().unwrap().map(|ps| { + let check_resp_status = async move { + let resp = resp.await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); + }; + let check_pushed_response = async move { + let p = pushed.and_then(|headers| { + async move { + let (request, response) = headers.into_parts(); + assert_eq!(request.into_parts().0.method, Method::GET); + let resp = response.await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let b = resp.into_body().try_concat().await.unwrap(); + assert_eq!(b, "promised_data"); + Ok(()) + } + }); + let ps: Vec<_> = p.collect().await; assert_eq!(1, ps.len()) - }); - h2.drive(check_resp_status.join(check_pushed_response)) - }); + }; + + h2.drive(join(check_resp_status, check_pushed_response)) + .await; + }; - h2.join(mock).wait().unwrap(); + join(mock, h2).await; } -#[test] -fn pushed_streams_arent_dropped_too_early() { +#[tokio::test] +async fn pushed_streams_arent_dropped_too_early() { // tests that by default, received push promises work let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - let mock = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let (io, mut srv) = mock::new(); + let mock = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .send_frame(frames::headers(1).response(404)) - .send_frame(frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css")) - .send_frame(frames::push_promise(1, 4).request("GET", "https://http2.akamai.com/style2.css")) - .send_frame(frames::data(1, "").eos()) - .idle_ms(10) - .send_frame(frames::headers(2).response(200)) - .send_frame(frames::headers(4).response(200).eos()) - .send_frame(frames::data(2, "").eos()) - .recv_frame(frames::go_away(4)); - - let h2 = client::handshake(io).unwrap().and_then(|(mut client, h2)| { + .await; + srv.send_frame(frames::headers(1).response(404)).await; + srv.send_frame( + frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css"), + ) + .await; + srv.send_frame( + frames::push_promise(1, 4).request("GET", "https://http2.akamai.com/style2.css"), + ) + .await; + srv.send_frame(frames::data(1, "").eos()).await; + idle_ms(10).await; + srv.send_frame(frames::headers(2).response(200)).await; + srv.send_frame(frames::headers(4).response(200).eos()).await; + srv.send_frame(frames::data(2, "").eos()).await; + srv.recv_frame(frames::go_away(4)).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); let request = Request::builder() .method(Method::GET) .uri("https://http2.akamai.com/") .body(()) .unwrap(); - let (mut resp, _) = client - .send_request(request, true) - .unwrap(); - let pushed = resp.push_promises(); - let check_status = resp.unwrap().and_then(|resp| { + let (mut resp, _) = client.send_request(request, true).unwrap(); + let mut pushed = resp.push_promises(); + let check_status = async move { + let resp = resp.await.unwrap(); assert_eq!(resp.status(), StatusCode::NOT_FOUND); - Ok(()) - }); - let check_pushed_headers = pushed.and_then(|headers| { - let (request, response) = headers.into_parts(); - assert_eq!(request.into_parts().0.method, Method::GET); - response - }); - let check_pushed = check_pushed_headers.map( - |resp| assert_eq!(resp.status(), StatusCode::OK) - ).collect().unwrap().and_then(|ps| { - assert_eq!(2, ps.len()); - Ok(()) - }); - h2.drive(check_status.join(check_pushed)).and_then(|(conn, _)| conn.expect("client")) - }); - - h2.join(mock).wait().unwrap(); + }; + + let check_pushed = async move { + let mut count = 0; + while let Some(headers) = pushed.next().await { + let (request, response) = headers.unwrap().into_parts(); + assert_eq!(request.into_parts().0.method, Method::GET); + let resp = response.await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + count += 1; + } + assert_eq!(2, count); + }; + + drop(client); + + h2.drive(join(check_pushed, check_status)).await; + h2.await.expect("client"); + }; + + join(mock, h2).await; } -#[test] -fn recv_push_when_push_disabled_is_conn_error() { +#[tokio::test] +async fn recv_push_when_push_disabled_is_conn_error() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - let mock = srv.assert_client_handshake() - .unwrap() - .ignore_settings() - .recv_frame( + let (io, mut srv) = mock::new(); + let mock = async move { + let _ = srv.assert_client_handshake().await; + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .send_frame(frames::push_promise(1, 3).request("GET", "https://http2.akamai.com/style.css")) - .send_frame(frames::headers(1).response(200).eos()) - .recv_frame(frames::go_away(0).protocol_error()); - - let h2 = client::Builder::new() - .enable_push(false) - .handshake::<_, Bytes>(io) - .unwrap() - .and_then(|(mut client, h2)| { - let request = Request::builder() - .method(Method::GET) - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); - - let req = client.send_request(request, true).unwrap().0.then(|res| { - let err = res.unwrap_err(); - assert_eq!( - err.to_string(), - "protocol error: unspecific protocol error detected" - ); - Ok::<(), ()>(()) - }); - - // client should see a protocol error - let conn = h2.then(|res| { - let err = res.unwrap_err(); - assert_eq!( - err.to_string(), - "protocol error: unspecific protocol error detected" - ); - Ok::<(), ()>(()) - }); - - conn.unwrap().join(req) - }); + .await; + srv.send_frame( + frames::push_promise(1, 3).request("GET", "https://http2.akamai.com/style.css"), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.recv_frame(frames::go_away(0).protocol_error()).await; + }; + + let h2 = async move { + let (mut client, h2) = client::Builder::new() + .enable_push(false) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method(Method::GET) + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); - h2.join(mock).wait().unwrap(); + let req = async move { + let res = client.send_request(request, true).unwrap().0.await; + let err = res.unwrap_err(); + assert_eq!( + err.to_string(), + "protocol error: unspecific protocol error detected" + ); + }; + + // client should see a protocol error + let conn = async move { + let res = h2.await; + let err = res.unwrap_err(); + assert_eq!( + err.to_string(), + "protocol error: unspecific protocol error detected" + ); + }; + + join(conn, req).await; + }; + + join(mock, h2).await; } -#[test] -fn pending_push_promises_reset_when_dropped() { +#[tokio::test] +async fn pending_push_promises_reset_when_dropped() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let (io, mut srv) = mock::new(); + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .send_frame( - frames::push_promise(1, 2) - .request("GET", "https://http2.akamai.com/style.css") + .await; + srv.send_frame( + frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css"), ) - .send_frame(frames::headers(1).response(200).eos()) - .recv_frame(frames::reset(2).cancel()) - .close(); + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.recv_frame(frames::reset(2).cancel()).await; + }; - let client = client::handshake(io).unwrap().and_then(|(mut client, conn)| { + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.unwrap(); let request = Request::builder() .method(Method::GET) .uri("https://http2.akamai.com/") .body(()) .unwrap(); - let req = client - .send_request(request, true) - .unwrap() - .0.expect("response") - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - Ok(()) - }); + let req = async { + let resp = client + .send_request(request, true) + .unwrap() + .0 + .await + .expect("response"); + assert_eq!(resp.status(), StatusCode::OK); + }; - conn.drive(req) - .and_then(move |(conn, _)| conn.expect("client").map(move |()| drop(client))) - }); + let _ = conn.drive(req).await; + conn.await.expect("client"); + drop(client); + }; - client.join(srv).wait().expect("wait"); + join(srv, client).await; } -#[test] -fn recv_push_promise_over_max_header_list_size() { +#[tokio::test] +async fn recv_push_promise_over_max_header_list_size() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_custom_settings( - frames::settings() - .max_header_list_size(10) - ) - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_frame_eq(settings, frames::settings().max_header_list_size(10)); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .send_frame(frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css")) - .recv_frame(frames::reset(2).refused()) - .send_frame(frames::headers(1).response(200).eos()) - .idle_ms(10) - .close(); - - let client = client::Builder::new() - .max_header_list_size(10) - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let request = Request::builder() - .uri("https://http2.akamai.com/") - .body(()) - .unwrap(); - - let req = client + .await; + srv.send_frame( + frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css"), + ) + .await; + srv.recv_frame(frames::reset(2).refused()).await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + idle_ms(10).await; + }; + + let client = async move { + let (mut client, mut conn) = client::Builder::new() + .max_header_list_size(10) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + let request = Request::builder() + .uri("https://http2.akamai.com/") + .body(()) + .unwrap(); + + let req = async move { + let err = client .send_request(request, true) .expect("send_request") .0 - .expect_err("response") - .map(|err| { - assert_eq!( - err.reason(), - Some(Reason::REFUSED_STREAM) - ); - }); - - conn.drive(req) - .and_then(|(conn, _)| conn.expect("client")) - }); - client.join(srv).wait().expect("wait"); + .await + .expect_err("response"); + assert_eq!(err.reason(), Some(Reason::REFUSED_STREAM)); + }; + + conn.drive(req).await; + conn.await.expect("client"); + }; + join(srv, client).await; } -#[test] -fn recv_invalid_push_promise_headers_is_stream_protocol_error() { +#[tokio::test] +async fn recv_invalid_push_promise_headers_is_stream_protocol_error() { // Unsafe method or content length is stream protocol error let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - let mock = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let (io, mut srv) = mock::new(); + let mock = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .send_frame(frames::headers(1).response(404)) - .send_frame(frames::push_promise(1, 2).request("POST", "https://http2.akamai.com/style.css")) - .send_frame( + .await; + srv.send_frame(frames::headers(1).response(404)).await; + srv.send_frame( + frames::push_promise(1, 2).request("POST", "https://http2.akamai.com/style.css"), + ) + .await; + srv.send_frame( frames::push_promise(1, 4) .request("GET", "https://http2.akamai.com/style.css") - .field(http::header::CONTENT_LENGTH, 1) + .field(http::header::CONTENT_LENGTH, 1), ) - .send_frame( + .await; + srv.send_frame( frames::push_promise(1, 6) .request("GET", "https://http2.akamai.com/style.css") - .field(http::header::CONTENT_LENGTH, 0) + .field(http::header::CONTENT_LENGTH, 0), ) - .send_frame(frames::headers(1).response(404).eos()) - .recv_frame(frames::reset(2).protocol_error()) - .recv_frame(frames::reset(4).protocol_error()) - .send_frame(frames::headers(6).response(200).eos()) - .close(); - - let h2 = client::handshake(io).unwrap().and_then(|(mut client, h2)| { + .await; + srv.send_frame(frames::headers(1).response(404).eos()).await; + srv.recv_frame(frames::reset(2).protocol_error()).await; + srv.recv_frame(frames::reset(4).protocol_error()).await; + srv.send_frame(frames::headers(6).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); let request = Request::builder() .method(Method::GET) .uri("https://http2.akamai.com/") .body(()) .unwrap(); - let (mut resp, _) = client - .send_request(request, true) - .unwrap(); - let check_pushed_request = resp.push_promises().and_then(|headers| { - headers.into_parts().1 - }); - let check_pushed_response = check_pushed_request - .collect().unwrap().map(|ps| { - // CONTENT_LENGTH = 0 is ok - assert_eq!(1, ps.len()) - }); - h2.drive(check_pushed_response) - }); - - h2.join(mock).wait().unwrap(); + let (mut resp, _) = client.send_request(request, true).unwrap(); + let check_pushed_response = async move { + let pushed = resp.push_promises(); + let p = pushed.and_then(|headers| headers.into_parts().1); + let ps: Vec<_> = p.collect().await; + // CONTENT_LENGTH = 0 is ok + assert_eq!(1, ps.len()); + }; + h2.drive(check_pushed_response).await; + }; + + join(mock, h2).await; } #[test] @@ -313,102 +349,110 @@ fn recv_push_promise_with_wrong_authority_is_stream_error() { // if server is foo.com, :authority = bar.com is stream error } -#[test] -fn recv_push_promise_skipped_stream_id() { +#[tokio::test] +async fn recv_push_promise_skipped_stream_id() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - let mock = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let (io, mut srv) = mock::new(); + let mock = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .send_frame(frames::push_promise(1, 4).request("GET", "https://http2.akamai.com/style.css")) - .send_frame(frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css")) - .recv_frame(frames::go_away(0).protocol_error()) - .close(); + .await; + srv.send_frame( + frames::push_promise(1, 4).request("GET", "https://http2.akamai.com/style.css"), + ) + .await; + srv.send_frame( + frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css"), + ) + .await; + srv.recv_frame(frames::go_away(0).protocol_error()).await; + }; - let h2 = client::handshake(io).unwrap().and_then(|(mut client, h2)| { + let h2 = async move { + let (mut client, h2) = client::handshake(io).await.unwrap(); let request = Request::builder() .method(Method::GET) .uri("https://http2.akamai.com/") .body(()) .unwrap(); - let req = client - .send_request(request, true) - .unwrap() - .0 - .then(|res| { - assert!(res.is_err()); - Ok::<_, ()>(()) - }); - - // client should see a protocol error - let conn = h2.then(|res| { - let err = res.unwrap_err(); - assert_eq!( - err.to_string(), - "protocol error: unspecific protocol error detected" - ); - Ok::<(), ()>(()) - }); - - conn.unwrap().join(req) - }); - - h2.join(mock).wait().unwrap(); + let req = async move { + let res = client.send_request(request, true).unwrap().0.await; + assert!(res.is_err()); + }; + + // client should see a protocol error + let conn = async move { + let res = h2.await; + let err = res.unwrap_err(); + assert_eq!( + err.to_string(), + "protocol error: unspecific protocol error detected" + ); + }; + + join(conn, req).await; + }; + + join(mock, h2).await; } -#[test] -fn recv_push_promise_dup_stream_id() { +#[tokio::test] +async fn recv_push_promise_dup_stream_id() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - let mock = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let (io, mut srv) = mock::new(); + let mock = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://http2.akamai.com/") .eos(), ) - .send_frame(frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css")) - .send_frame(frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css")) - .recv_frame(frames::go_away(0).protocol_error()) - .close(); + .await; + srv.send_frame( + frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css"), + ) + .await; + srv.send_frame( + frames::push_promise(1, 2).request("GET", "https://http2.akamai.com/style.css"), + ) + .await; + srv.recv_frame(frames::go_away(0).protocol_error()).await; + }; - let h2 = client::handshake(io).unwrap().and_then(|(mut client, h2)| { + let h2 = async move { + let (mut client, h2) = client::handshake(io).await.unwrap(); let request = Request::builder() .method(Method::GET) .uri("https://http2.akamai.com/") .body(()) .unwrap(); - let req = client - .send_request(request, true) - .unwrap() - .0 - .then(|res| { - assert!(res.is_err()); - Ok::<_, ()>(()) - }); - - // client should see a protocol error - let conn = h2.then(|res| { - let err = res.unwrap_err(); - assert_eq!( - err.to_string(), - "protocol error: unspecific protocol error detected" - ); - Ok::<(), ()>(()) - }); - - conn.unwrap().join(req) - }); - - h2.join(mock).wait().unwrap(); + let req = async move { + let res = client.send_request(request, true).unwrap().0.await; + assert!(res.is_err()); + }; + + // client should see a protocol error + let conn = async move { + let res = h2.await; + let err = res.unwrap_err(); + assert_eq!( + err.to_string(), + "protocol error: unspecific protocol error detected" + ); + }; + + join(conn, req).await; + }; + + join(mock, h2).await; } diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index 8c5d11aaa..857d08172 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -1,12 +1,16 @@ +#![feature(async_await)] #![deny(warnings)] +use futures::future::{join, poll_fn}; +use futures::{StreamExt, TryStreamExt}; use h2_support::prelude::*; +use tokio::io::AsyncWriteExt; const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; -#[test] -fn read_preface_in_multiple_frames() { +#[tokio::test] +async fn read_preface_in_multiple_frames() { let _ = env_logger::try_init(); let mock = mock_io::Builder::new() @@ -18,102 +22,98 @@ fn read_preface_in_multiple_frames() { .read(SETTINGS_ACK) .build(); - let h2 = server::handshake(mock).wait().unwrap(); + let mut h2 = server::handshake(mock).await.unwrap(); - assert!(Stream::wait(h2).next().is_none()); + assert!(h2.next().await.is_none()); } -#[test] -fn server_builder_set_max_concurrent_streams() { +#[tokio::test] +async fn server_builder_set_max_concurrent_streams() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); + let (io, mut client) = mock::new(); let mut settings = frame::Settings::default(); settings.set_max_concurrent_streams(Some(1)); - let client = client - .assert_server_handshake() - .unwrap() - .recv_custom_settings(settings) - .send_frame( - frames::headers(1) - .request("GET", "https://example.com/"), - ) - .send_frame( - frames::headers(3) - .request("GET", "https://example.com/"), - ) - .send_frame(frames::data(1, &b"hello"[..]).eos(),) - .recv_frame(frames::reset(3).refused()) - .recv_frame(frames::headers(1).response(200).eos()) - .close(); + let client = async move { + let recv_settings = client.assert_server_handshake().await; + assert_frame_eq(recv_settings, settings); + client + .send_frame(frames::headers(1).request("GET", "https://example.com/")) + .await; + client + .send_frame(frames::headers(3).request("GET", "https://example.com/")) + .await; + client + .send_frame(frames::data(1, &b"hello"[..]).eos()) + .await; + client.recv_frame(frames::reset(3).refused()).await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; + }; let mut builder = server::Builder::new(); builder.max_concurrent_streams(1); - let h2 = builder - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|srv| { - srv.into_future().unwrap().and_then(|(reqstream, srv)| { - let (req, mut stream) = reqstream.unwrap(); + let h2 = async move { + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + let (req, mut stream) = srv.next().await.unwrap().unwrap(); - assert_eq!(req.method(), &http::Method::GET); + assert_eq!(req.method(), &http::Method::GET); - let rsp = - http::Response::builder() - .status(200).body(()) - .unwrap(); - stream.send_response(rsp, true).unwrap(); + let rsp = http::Response::builder().status(200).body(()).unwrap(); + stream.send_response(rsp, true).unwrap(); - srv.into_future().unwrap().map(|_| ()) - }) - }); + assert!(srv.next().await.is_none()); + }; - h2.join(client).wait().expect("wait"); + join(client, h2).await; } -#[test] -fn serve_request() { +#[tokio::test] +async fn serve_request() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client - .assert_server_handshake() - .unwrap() - .recv_settings() - .send_frame( - frames::headers(1) - .request("GET", "https://example.com/") - .eos(), - ) - .recv_frame(frames::headers(1).response(200).eos()) - .close(); + let (io, mut client) = mock::new(); - let srv = server::handshake(io).expect("handshake").and_then(|srv| { - srv.into_future().unwrap().and_then(|(reqstream, srv)| { - let (req, mut stream) = reqstream.unwrap(); + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; + }; - assert_eq!(req.method(), &http::Method::GET); + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + let (req, mut stream) = srv.next().await.unwrap().unwrap(); - let rsp = http::Response::builder().status(200).body(()).unwrap(); - stream.send_response(rsp, true).unwrap(); + assert_eq!(req.method(), &http::Method::GET); - srv.into_future().unwrap().map(|_| ()) - }) - }); + let rsp = http::Response::builder().status(200).body(()).unwrap(); + stream.send_response(rsp, true).unwrap(); + + assert!(srv.next().await.is_none()); + }; - srv.join(client).wait().expect("wait"); + join(client, srv).await; } #[test] #[ignore] fn accept_with_pending_connections_after_socket_close() {} -#[test] -fn recv_invalid_authority() { +#[tokio::test] +async fn recv_invalid_authority() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); + let (io, mut client) = mock::new(); let bad_auth = util::byte_str("not:a/good authority"); let mut bad_headers: frame::Headers = frames::headers(1) @@ -122,25 +122,25 @@ fn recv_invalid_authority() { .into(); bad_headers.pseudo_mut().authority = Some(bad_auth); - let client = client - .assert_server_handshake() - .unwrap() - .recv_settings() - .send_frame(bad_headers) - .recv_frame(frames::reset(1).protocol_error()) - .close(); + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client.send_frame(bad_headers).await; + client.recv_frame(frames::reset(1).protocol_error()).await; + }; - let srv = server::handshake(io) - .expect("handshake") - .and_then(|srv| srv.into_future().unwrap().map(|_| ())); + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + assert!(srv.next().await.is_none()); + }; - srv.join(client).wait().expect("wait"); + join(client, srv).await; } -#[test] -fn recv_connection_header() { +#[tokio::test] +async fn recv_connection_header() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); + let (io, mut client) = mock::new(); let req = |id, name, val| { frames::headers(id) @@ -149,511 +149,484 @@ fn recv_connection_header() { .eos() }; - let client = client - .assert_server_handshake() - .unwrap() - .recv_settings() - .send_frame(req(1, "connection", "foo")) - .send_frame(req(3, "keep-alive", "5")) - .send_frame(req(5, "proxy-connection", "bar")) - .send_frame(req(7, "transfer-encoding", "chunked")) - .send_frame(req(9, "upgrade", "HTTP/2.0")) - .recv_frame(frames::reset(1).protocol_error()) - .recv_frame(frames::reset(3).protocol_error()) - .recv_frame(frames::reset(5).protocol_error()) - .recv_frame(frames::reset(7).protocol_error()) - .recv_frame(frames::reset(9).protocol_error()) - .close(); - - let srv = server::handshake(io) - .expect("handshake") - .and_then(|srv| srv.into_future().unwrap()).map(|_| ()); - - srv.join(client).wait().expect("wait"); + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client.send_frame(req(1, "connection", "foo")).await; + client.send_frame(req(3, "keep-alive", "5")).await; + client.send_frame(req(5, "proxy-connection", "bar")).await; + client + .send_frame(req(7, "transfer-encoding", "chunked")) + .await; + client.send_frame(req(9, "upgrade", "HTTP/2.0")).await; + client.recv_frame(frames::reset(1).protocol_error()).await; + client.recv_frame(frames::reset(3).protocol_error()).await; + client.recv_frame(frames::reset(5).protocol_error()).await; + client.recv_frame(frames::reset(7).protocol_error()).await; + client.recv_frame(frames::reset(9).protocol_error()).await; + }; + + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + assert!(srv.next().await.is_none()); + }; + + join(client, srv).await; } -#[test] -fn sends_reset_cancel_when_req_body_is_dropped() { +#[tokio::test] +async fn sends_reset_cancel_when_req_body_is_dropped() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client - .assert_server_handshake() - .unwrap() - .recv_settings() - .send_frame( - frames::headers(1) - .request("POST", "https://example.com/") - ) - .recv_frame(frames::headers(1).response(200).eos()) - .recv_frame(frames::reset(1).cancel()) - .close(); + let (io, mut client) = mock::new(); - let srv = server::handshake(io).expect("handshake").and_then(|srv| { - srv.into_future().unwrap().and_then(|(reqstream, srv)| { - let (req, mut stream) = reqstream.unwrap(); + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame(frames::headers(1).request("POST", "https://example.com/")) + .await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; + client.recv_frame(frames::reset(1).cancel()).await; + }; + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + { + let (req, mut stream) = srv.next().await.unwrap().unwrap(); assert_eq!(req.method(), &http::Method::POST); let rsp = http::Response::builder().status(200).body(()).unwrap(); stream.send_response(rsp, true).unwrap(); + } + assert!(srv.next().await.is_none()); + }; - srv.into_future().unwrap().map(|_| ()) - }) - }); - - srv.join(client).wait().expect("wait"); + join(client, srv).await; } -#[test] -fn abrupt_shutdown() { +#[tokio::test] +async fn abrupt_shutdown() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client - .assert_server_handshake() - .unwrap() - .recv_settings() - .send_frame( - frames::headers(1) - .request("POST", "https://example.com/") - ) - .recv_frame(frames::go_away(1).internal_error()) - .recv_eof(); - - let srv = server::handshake(io).expect("handshake").and_then(|srv| { - srv.into_future().unwrap().and_then(|(item, mut srv)| { - let (req, tx) = item.expect("server receives request"); - - let req_fut = req - .into_body() - .concat2() - .map(|_| drop(tx)) - .expect_err("request body should error") - .map(|err| { - assert_eq!( - err.reason(), - Some(Reason::INTERNAL_ERROR), - "streams should be also error with user's reason", - ); - }); - - srv.abrupt_shutdown(Reason::INTERNAL_ERROR); - - let srv_fut = futures::future::poll_fn(move || { - srv.poll_close() - }).expect("server"); - - req_fut.join(srv_fut) - }) - }); + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame(frames::headers(1).request("POST", "https://example.com/")) + .await; + client.recv_frame(frames::go_away(1).internal_error()).await; + client.recv_eof().await; + }; - srv.join(client).wait().expect("wait"); + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + let (req, tx) = srv.next().await.unwrap().expect("server receives request"); + + let req_fut = async move { + let body = req.into_body().try_concat().await; + drop(tx); + let err = body.expect_err("request body should error"); + assert_eq!( + err.reason(), + Some(Reason::INTERNAL_ERROR), + "streams should be also error with user's reason", + ); + }; + + srv.abrupt_shutdown(Reason::INTERNAL_ERROR); + + let srv_fut = async move { + poll_fn(move |cx| srv.poll_close(cx)).await.expect("server"); + }; + + join(req_fut, srv_fut).await; + }; + + join(client, srv).await; } -#[test] -fn graceful_shutdown() { +#[tokio::test] +async fn graceful_shutdown() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client - .assert_server_handshake() - .unwrap() - .recv_settings() - .send_frame( - frames::headers(1) - .request("GET", "https://example.com/") - .eos(), - ) + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; // 2^31 - 1 = 2147483647 // Note: not using a constant in the library because library devs // can be unsmart. - .recv_frame(frames::go_away(2147483647)) - .recv_frame(frames::ping(frame::Ping::SHUTDOWN)) - .recv_frame(frames::headers(1).response(200).eos()) + client.recv_frame(frames::go_away(2147483647)).await; + client.recv_frame(frames::ping(frame::Ping::SHUTDOWN)).await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; // Pretend this stream was sent while the GOAWAY was in flight - .send_frame( - frames::headers(3) - .request("POST", "https://example.com/"), - ) - .send_frame(frames::ping(frame::Ping::SHUTDOWN).pong()) - .recv_frame(frames::go_away(3)) + client + .send_frame(frames::headers(3).request("POST", "https://example.com/")) + .await; + client + .send_frame(frames::ping(frame::Ping::SHUTDOWN).pong()) + .await; + client.recv_frame(frames::go_away(3)).await; // streams sent after GOAWAY receive no response - .send_frame( - frames::headers(7) - .request("GET", "https://example.com/"), - ) - .send_frame(frames::data(7, "").eos()) - .send_frame(frames::data(3, "").eos()) - .recv_frame(frames::headers(3).response(200).eos()) - .recv_eof(); - - let srv = server::handshake(io) - .expect("handshake") - .and_then(|srv| { - srv.into_future().unwrap() - }) - .and_then(|(reqstream, mut srv)| { - let (req, mut stream) = reqstream.unwrap(); + client + .send_frame(frames::headers(7).request("GET", "https://example.com/")) + .await; + client.send_frame(frames::data(7, "").eos()).await; + client.send_frame(frames::data(3, "").eos()).await; + client + .recv_frame(frames::headers(3).response(200).eos()) + .await; + client.recv_eof().await; + }; - assert_eq!(req.method(), &http::Method::GET); + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + let (req, mut stream) = srv.next().await.unwrap().unwrap(); + assert_eq!(req.method(), &http::Method::GET); + + srv.graceful_shutdown(); + + let rsp = http::Response::builder().status(200).body(()).unwrap(); + stream.send_response(rsp, true).unwrap(); + + let (req, mut stream) = srv.next().await.unwrap().unwrap(); + assert_eq!(req.method(), &http::Method::POST); + let body = req.into_parts().1; - srv.graceful_shutdown(); + let body = async move { + let buf = body.try_concat().await.unwrap(); + assert!(buf.is_empty()); - let rsp = http::Response::builder() - .status(200) - .body(()) - .unwrap(); + let rsp = http::Response::builder().status(200).body(()).unwrap(); stream.send_response(rsp, true).unwrap(); + }; - srv.into_future().unwrap() - }) - .and_then(|(reqstream, srv)| { - let (req, mut stream) = reqstream.unwrap(); - assert_eq!(req.method(), &http::Method::POST); - let body = req.into_parts().1; - - let body = body.concat2().and_then(move |buf| { - assert!(buf.is_empty()); - - let rsp = http::Response::builder() - .status(200) - .body(()) - .unwrap(); - stream.send_response(rsp, true).unwrap(); - Ok(()) - }); - - srv.into_future() - .map(|(req, _srv)| { - assert!(req.is_none(), "unexpected request"); - }) - .drive(body) - .and_then(|(srv, ())| { - srv.expect("srv") - }) + let mut srv = Box::pin(async move { + assert!(srv.next().await.is_none(), "unexpected request"); }); + srv.drive(body).await; + srv.await; + }; - srv.join(client).wait().expect("wait"); + join(client, srv).await; } -#[test] -fn sends_reset_cancel_when_res_body_is_dropped() { +#[tokio::test] +async fn sends_reset_cancel_when_res_body_is_dropped() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client - .assert_server_handshake() - .unwrap() - .recv_settings() - .send_frame( - frames::headers(1) - .request("GET", "https://example.com/") - .eos() - ) - .recv_frame(frames::headers(1).response(200)) - .recv_frame(frames::reset(1).cancel()) - .send_frame( - frames::headers(3) - .request("GET", "https://example.com/") - .eos() - ) - .recv_frame(frames::headers(3).response(200)) - .recv_frame(frames::data(3, vec![0; 10])) - .recv_frame(frames::reset(3).cancel()) - .close(); + let (io, mut client) = mock::new(); - let srv = server::handshake(io).expect("handshake").and_then(|srv| { - srv.into_future().unwrap().and_then(|(reqstream, srv)| { - let (req, mut stream) = reqstream.unwrap(); + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + client.recv_frame(frames::headers(1).response(200)).await; + client.recv_frame(frames::reset(1).cancel()).await; + client + .send_frame( + frames::headers(3) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + client.recv_frame(frames::headers(3).response(200)).await; + client.recv_frame(frames::data(3, vec![0; 10])).await; + client.recv_frame(frames::reset(3).cancel()).await; + }; + + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + { + let (req, mut stream) = srv.next().await.unwrap().unwrap(); assert_eq!(req.method(), &http::Method::GET); - let rsp = http::Response::builder() - .status(200) - .body(()) - .unwrap(); + let rsp = http::Response::builder().status(200).body(()).unwrap(); stream.send_response(rsp, false).unwrap(); // SendStream dropped - - srv.into_future().unwrap() - }).and_then(|(reqstream, srv)| { - let (_req, mut stream) = reqstream.unwrap(); - - let rsp = http::Response::builder() - .status(200) - .body(()) - .unwrap(); + } + { + let (_req, mut stream) = srv.next().await.unwrap().unwrap(); + let rsp = http::Response::builder().status(200).body(()).unwrap(); let mut tx = stream.send_response(rsp, false).unwrap(); tx.send_data(vec![0; 10].into(), false).unwrap(); // no send_data with eos + } - srv.into_future().unwrap().map(|_| ()) - }) - }); + assert!(srv.next().await.is_none()); + }; - srv.join(client).wait().expect("wait"); + join(client, srv).await; } -#[test] -fn too_big_headers_sends_431() { +#[tokio::test] +async fn too_big_headers_sends_431() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client - .assert_server_handshake() - .unwrap() - .recv_custom_settings( - frames::settings() - .max_header_list_size(10) - ) - .send_frame( - frames::headers(1) - .request("GET", "https://example.com/") - .field("some-header", "some-value") - .eos() - ) - .recv_frame(frames::headers(1).response(431).eos()) - .idle_ms(10) - .close(); - - let srv = server::Builder::new() - .max_header_list_size(10) - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|srv| { - srv.into_future() - .expect("server") - .map(|(req, _)| { - assert!(req.is_none(), "req is {:?}", req); - }) - }); + let (io, mut client) = mock::new(); - srv.join(client).wait().expect("wait"); + let client = async move { + let settings = client.assert_server_handshake().await; + assert_frame_eq(settings, frames::settings().max_header_list_size(10)); + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .field("some-header", "some-value") + .eos(), + ) + .await; + client + .recv_frame(frames::headers(1).response(431).eos()) + .await; + idle_ms(10).await; + }; + + let srv = async move { + let mut srv = server::Builder::new() + .max_header_list_size(10) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let req = srv.next().await; + assert!(req.is_none(), "req is {:?}", req); + }; + + join(client, srv).await; } -#[test] -fn too_big_headers_sends_reset_after_431_if_not_eos() { +#[tokio::test] +async fn too_big_headers_sends_reset_after_431_if_not_eos() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client - .assert_server_handshake() - .unwrap() - .recv_custom_settings( - frames::settings() - .max_header_list_size(10) - ) - .send_frame( - frames::headers(1) - .request("GET", "https://example.com/") - .field("some-header", "some-value") - ) - .recv_frame(frames::headers(1).response(431).eos()) - .recv_frame(frames::reset(1).refused()) - .close(); - - let srv = server::Builder::new() - .max_header_list_size(10) - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|srv| { - srv.into_future() - .expect("server") - .map(|(req, _)| { - assert!(req.is_none(), "req is {:?}", req); - }) - }); + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_frame_eq(settings, frames::settings().max_header_list_size(10)); + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .field("some-header", "some-value"), + ) + .await; + client + .recv_frame(frames::headers(1).response(431).eos()) + .await; + client.recv_frame(frames::reset(1).refused()).await; + }; - srv.join(client).wait().expect("wait"); + let srv = async move { + let mut srv = server::Builder::new() + .max_header_list_size(10) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let req = srv.next().await; + assert!(req.is_none(), "req is {:?}", req); + }; + + join(client, srv).await; } -#[test] -fn poll_reset() { +#[tokio::test] +async fn poll_reset() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client - .assert_server_handshake() - .unwrap() - .recv_settings() - .send_frame( - frames::headers(1) - .request("GET", "https://example.com/") - .eos() - ) - .idle_ms(10) - .send_frame(frames::reset(1).cancel()) - .close(); - - let srv = server::Builder::new() - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|srv| { - srv.into_future() - .expect("server") - .map(|(req, conn)| { - (req.expect("request"), conn) - }) - }) - .and_then(|((_req, mut tx), conn)| { - let conn = conn.into_future() - .map(|(req, _)| assert!(req.is_none(), "no second request")) - .expect("conn"); - conn.join( - futures::future::poll_fn(move || { - tx.poll_reset() - }) - .map(|reason| { - assert_eq!(reason, Reason::CANCEL); - }) - .expect("poll_reset") + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), ) - }); + .await; + idle_ms(10).await; + client.send_frame(frames::reset(1).cancel()).await; + }; - srv.join(client).wait().expect("wait"); + let srv = async move { + let mut srv = server::Builder::new() + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + let (_req, mut tx) = srv.next().await.expect("server").unwrap(); + let conn = async move { + let req = srv.next().await; + assert!(req.is_none(), "no second request"); + }; + join(conn, async move { + let reason = poll_fn(move |cx| tx.poll_reset(cx)) + .await + .expect("poll_reset"); + assert_eq!(reason, Reason::CANCEL); + }) + .await; + }; + join(client, srv).await; } -#[test] -fn poll_reset_io_error() { +#[tokio::test] +async fn poll_reset_io_error() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client - .assert_server_handshake() - .unwrap() - .recv_settings() - .send_frame( - frames::headers(1) - .request("GET", "https://example.com/") - .eos() - ) - .idle_ms(10) - .close(); - - let srv = server::Builder::new() - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|srv| { - srv.into_future() - .expect("server") - .map(|(req, conn)| { - (req.expect("request"), conn) - }) - }) - .and_then(|((_req, mut tx), conn)| { - let conn = conn.into_future() - .map(|(req, _)| assert!(req.is_none(), "no second request")) - .expect("conn"); - conn.join( - futures::future::poll_fn(move || { - tx.poll_reset() - }) - .expect_err("poll_reset should error") + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), ) - }); + .await; + idle_ms(10).await; + }; + + let srv = async move { + let mut srv = server::Builder::new() + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let (_req, mut tx) = srv.next().await.expect("server").unwrap(); + let conn = async move { + let req = srv.next().await; + assert!(req.is_none(), "no second request"); + }; + join(conn, async move { + poll_fn(move |cx| tx.poll_reset(cx)) + .await + .expect_err("poll_reset should error") + }) + .await; + }; - srv.join(client).wait().expect("wait"); + join(client, srv).await; } -#[test] -fn poll_reset_after_send_response_is_user_error() { +#[tokio::test] +async fn poll_reset_after_send_response_is_user_error() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client - .assert_server_handshake() - .unwrap() - .recv_settings() - .send_frame( - frames::headers(1) - .request("GET", "https://example.com/") - .eos() - ) - .recv_frame( - frames::headers(1) - .response(200) - ) - .recv_frame( - // After the error, our server will drop the handles, - // meaning we receive a RST_STREAM here. - frames::reset(1).cancel() - ) - .idle_ms(10) - .close(); - - let srv = server::Builder::new() - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|srv| { - srv.into_future() - .expect("server") - .map(|(req, conn)| { - (req.expect("request"), conn) - }) - }) - .and_then(|((_req, mut tx), conn)| { - let conn = conn.into_future() - .map(|(req, _)| assert!(req.is_none(), "no second request")) - .expect("conn"); - tx.send_response(Response::new(()), false).expect("response"); - conn.join( - futures::future::poll_fn(move || { - tx.poll_reset() - }) - .expect_err("poll_reset should error") + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), ) - }); + .await; + client.recv_frame(frames::headers(1).response(200)).await; + client + .recv_frame( + // After the error, our server will drop the handles, + // meaning we receive a RST_STREAM here. + frames::reset(1).cancel(), + ) + .await; + idle_ms(10).await; + }; - srv.join(client).wait().expect("wait"); -} + let srv = async move { + let mut srv = server::Builder::new() + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let (_req, mut tx) = srv.next().await.expect("server").expect("request"); + let conn = async move { + let req = srv.next().await; + assert!(req.is_none(), "no second request"); + }; + tx.send_response(Response::new(()), false) + .expect("response"); + drop(_req); + join( + async { + poll_fn(move |cx| tx.poll_reset(cx)) + .await + .expect_err("poll_reset should error") + }, + conn, + ) + .await; + }; -#[test] -fn server_error_on_unclean_shutdown() { - use std::io::Write; + join(client, srv).await; +} +#[tokio::test] +async fn server_error_on_unclean_shutdown() { let _ = env_logger::try_init(); let (io, mut client) = mock::new(); - let srv = server::Builder::new() - .handshake::<_, Bytes>(io); + let srv = server::Builder::new().handshake::<_, Bytes>(io); - client.write_all(b"PRI *").expect("write"); + client.write_all(b"PRI *").await.expect("write"); drop(client); - srv.wait().expect_err("should error"); + srv.await.expect_err("should error"); } -#[test] -fn request_without_authority() { +#[tokio::test] +async fn request_without_authority() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - - let client = client - .assert_server_handshake() - .unwrap() - .recv_settings() - .send_frame( - frames::headers(1) - .request("GET", "/just-a-path") - .scheme("http") - .eos() - ) - .recv_frame(frames::headers(1).response(200).eos()) - .close(); + let (io, mut client) = mock::new(); - let srv = server::handshake(io).expect("handshake").and_then(|srv| { - srv.into_future().unwrap().and_then(|(reqstream, srv)| { - let (req, mut stream) = reqstream.unwrap(); + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .request("GET", "/just-a-path") + .scheme("http") + .eos(), + ) + .await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; + }; - assert_eq!(req.uri().path(), "/just-a-path"); + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + let (req, mut stream) = srv.next().await.unwrap().unwrap(); + assert_eq!(req.uri().path(), "/just-a-path"); - let rsp = Response::new(()); - stream.send_response(rsp, true).unwrap(); + let rsp = Response::new(()); + stream.send_response(rsp, true).unwrap(); - srv.into_future().unwrap().map(|_| ()) - }) - }); + assert!(srv.next().await.is_none()); + }; - srv.join(client).wait().expect("wait"); + join(client, srv).await; } diff --git a/tests/h2-tests/tests/stream_states.rs b/tests/h2-tests/tests/stream_states.rs index 9ce89a100..0b298f1aa 100644 --- a/tests/h2-tests/tests/stream_states.rs +++ b/tests/h2-tests/tests/stream_states.rs @@ -1,24 +1,29 @@ +#![feature(async_await)] #![deny(warnings)] +use futures::future::{join, join3, lazy, try_join}; +use futures::{FutureExt, StreamExt, TryStreamExt}; use h2_support::prelude::*; +use h2_support::util::yield_once; +use std::task::Poll; -#[test] -fn send_recv_headers_only() { +#[tokio::test] +async fn send_recv_headers_only() { let _ = env_logger::try_init(); let mock = mock_io::Builder::new() .handshake() // Write GET / .write(&[ - 0, 0, 0x10, 1, 5, 0, 0, 0, 1, 0x82, 0x87, 0x41, 0x8B, 0x9D, 0x29, - 0xAC, 0x4B, 0x8F, 0xA8, 0xE9, 0x19, 0x97, 0x21, 0xE9, 0x84, + 0, 0, 0x10, 1, 5, 0, 0, 0, 1, 0x82, 0x87, 0x41, 0x8B, 0x9D, 0x29, 0xAC, 0x4B, 0x8F, + 0xA8, 0xE9, 0x19, 0x97, 0x21, 0xE9, 0x84, ]) .write(frames::SETTINGS_ACK) // Read response .read(&[0, 0, 1, 1, 5, 0, 0, 0, 1, 0x89]) .build(); - let (mut client, mut h2) = client::handshake(mock).wait().unwrap(); + let (mut client, mut h2) = client::handshake(mock).await.unwrap(); // Send the request let request = Request::builder() @@ -29,22 +34,22 @@ fn send_recv_headers_only() { log::info!("sending request"); let (response, _) = client.send_request(request, true).unwrap(); - let resp = h2.run(response).unwrap(); + let resp = h2.run(response).await.unwrap(); assert_eq!(resp.status(), StatusCode::NO_CONTENT); - h2.wait().unwrap(); + h2.await.unwrap(); } -#[test] -fn send_recv_data() { +#[tokio::test] +async fn send_recv_data() { let _ = env_logger::try_init(); let mock = mock_io::Builder::new() .handshake() .write(&[ // POST / - 0, 0, 16, 1, 4, 0, 0, 0, 1, 131, 135, 65, 139, 157, 41, - 172, 75, 143, 168, 233, 25, 151, 33, 233, 132, + 0, 0, 16, 1, 4, 0, 0, 0, 1, 131, 135, 65, 139, 157, 41, 172, 75, 143, 168, 233, 25, 151, + 33, 233, 132, ]) .write(&[ // DATA @@ -54,13 +59,12 @@ fn send_recv_data() { // Read response .read(&[ // HEADERS - 0, 0, 1, 1, 4, 0, 0, 0, 1, 136, - // DATA - 0, 0, 5, 0, 1, 0, 0, 0, 1, 119, 111, 114, 108, 100 + 0, 0, 1, 1, 4, 0, 0, 0, 1, 136, // DATA + 0, 0, 5, 0, 1, 0, 0, 0, 1, 119, 111, 114, 108, 100, ]) .build(); - let (mut client, mut h2) = client::Builder::new().handshake(mock).wait().unwrap(); + let (mut client, mut h2) = client::Builder::new().handshake(mock).await.unwrap(); let request = Request::builder() .method(Method::POST) @@ -80,14 +84,14 @@ fn send_recv_data() { stream.send_data("hello", true).unwrap(); // Get the response - let resp = h2.run(response).unwrap(); + let resp = h2.run(response).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); // Take the body let (_, body) = resp.into_parts(); // Wait for all the data frames to be received - let bytes = h2.run(body.collect()).unwrap(); + let bytes: Vec<_> = h2.run(body.try_collect()).await.unwrap(); // One byte chunk assert_eq!(1, bytes.len()); @@ -95,29 +99,29 @@ fn send_recv_data() { assert_eq!(bytes[0], &b"world"[..]); // The H2 connection is closed - h2.wait().unwrap(); + h2.await.unwrap(); } -#[test] -fn send_headers_recv_data_single_frame() { +#[tokio::test] +async fn send_headers_recv_data_single_frame() { let _ = env_logger::try_init(); let mock = mock_io::Builder::new() .handshake() // Write GET / .write(&[ - 0, 0, 16, 1, 5, 0, 0, 0, 1, 130, 135, 65, 139, 157, 41, 172, 75, - 143, 168, 233, 25, 151, 33, 233, 132 + 0, 0, 16, 1, 5, 0, 0, 0, 1, 130, 135, 65, 139, 157, 41, 172, 75, 143, 168, 233, 25, + 151, 33, 233, 132, ]) .write(frames::SETTINGS_ACK) // Read response .read(&[ - 0, 0, 1, 1, 4, 0, 0, 0, 1, 136, 0, 0, 5, 0, 0, 0, 0, 0, 1, 104, 101, - 108, 108, 111, 0, 0, 5, 0, 1, 0, 0, 0, 1, 119, 111, 114, 108, 100, + 0, 0, 1, 1, 4, 0, 0, 0, 1, 136, 0, 0, 5, 0, 0, 0, 0, 0, 1, 104, 101, 108, 108, 111, 0, + 0, 5, 0, 1, 0, 0, 0, 1, 119, 111, 114, 108, 100, ]) .build(); - let (mut client, mut h2) = client::handshake(mock).wait().unwrap(); + let (mut client, mut h2) = client::handshake(mock).await.unwrap(); // Send the request let request = Request::builder() @@ -128,14 +132,14 @@ fn send_headers_recv_data_single_frame() { log::info!("sending request"); let (response, _) = client.send_request(request, true).unwrap(); - let resp = h2.run(response).unwrap(); + let resp = h2.run(response).await.unwrap(); assert_eq!(resp.status(), StatusCode::OK); // Take the body let (_, body) = resp.into_parts(); // Wait for all the data frames to be received - let bytes = h2.run(body.collect()).unwrap(); + let bytes: Vec<_> = h2.run(body.try_collect()).await.unwrap(); // Two data frames assert_eq!(2, bytes.len()); @@ -144,208 +148,192 @@ fn send_headers_recv_data_single_frame() { assert_eq!(bytes[1], &b"world"[..]); // The H2 connection is closed - h2.wait().unwrap(); + h2.await.unwrap(); } -#[test] -fn closed_streams_are_released() { +#[tokio::test] +async fn closed_streams_are_released() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let h2 = client::handshake(io).unwrap().and_then(|(mut client, h2)| { + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); let request = Request::get("https://example.com/").body(()).unwrap(); // Send request let (response, _) = client.send_request(request, true).unwrap(); - h2.drive(response).and_then(move |(_, response)| { - assert_eq!(response.status(), StatusCode::NO_CONTENT); - - // There are no active streams - assert_eq!(0, client.num_active_streams()); + let response = h2.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::NO_CONTENT); - // The response contains a handle for the body. This keeps the - // stream wired. - assert_eq!(1, client.num_wired_streams()); + // There are no active streams + assert_eq!(0, client.num_active_streams()); - let (_, body) = response.into_parts(); - assert!(body.is_end_stream()); - drop(body); + // The response contains a handle for the body. This keeps the + // stream wired. + assert_eq!(1, client.num_wired_streams()); - // The stream state is now free - assert_eq!(0, client.num_wired_streams()); + let (_, body) = response.into_parts(); + assert!(body.is_end_stream()); + drop(body); - Ok(()) - }) - }); + // The stream state is now free + assert_eq!(0, client.num_wired_streams()); + }; - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://example.com/") .eos(), ) - .send_frame(frames::headers(1).response(204).eos()) - .close(); - - let _ = h2.join(srv).wait().unwrap(); + .await; + srv.send_frame(frames::headers(1).response(204).eos()).await; + }; + join(srv, h2).await; } -#[test] -fn errors_if_recv_frame_exceeds_max_frame_size() { +#[tokio::test] +async fn errors_if_recv_frame_exceeds_max_frame_size() { let _ = env_logger::try_init(); let (io, mut srv) = mock::new(); - let h2 = client::handshake(io).unwrap().and_then(|(mut client, h2)| { - let req = client - .get("https://example.com/") - .expect("response") - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - let body = resp.into_parts().1; - body.concat2().then(|res| { - let err = res.unwrap_err(); - assert_eq!(err.to_string(), "protocol error: frame with invalid size"); - Ok::<(), ()>(()) - }) - }); + let h2 = async move { + let (mut client, h2) = client::handshake(io).await.unwrap(); + let req = async move { + let resp = client.get("https://example.com/").await.expect("response"); + assert_eq!(resp.status(), StatusCode::OK); + let body = resp.into_parts().1; + let res = body.try_concat().await; + let err = res.unwrap_err(); + assert_eq!(err.to_string(), "protocol error: frame with invalid size"); + }; // client should see a conn error - let conn = h2.then(|res| { - let err = res.unwrap_err(); + let conn = async move { + let err = h2.await.unwrap_err(); assert_eq!(err.to_string(), "protocol error: frame with invalid size"); - Ok::<(), ()>(()) - }); - conn.unwrap().join(req) - }); + }; + join(conn, req).await; + }; // a bad peer srv.codec_mut().set_max_send_frame_size(16_384 * 4); - let srv = srv.assert_client_handshake() - .unwrap() - .ignore_settings() - .recv_frame( + let srv = async move { + let _ = srv.assert_client_handshake().await; + srv.recv_frame( frames::headers(1) .request("GET", "https://example.com/") .eos(), ) - .send_frame(frames::headers(1).response(200)) - .send_frame(frames::data(1, vec![0; 16_385]).eos()) - .recv_frame(frames::go_away(0).frame_size()) - .close(); + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.send_frame(frames::data(1, vec![0; 16_385]).eos()).await; + srv.recv_frame(frames::go_away(0).frame_size()).await; + }; - let _ = h2.join(srv).wait().unwrap(); + join(srv, h2).await; } - -#[test] -fn configure_max_frame_size() { +#[tokio::test] +async fn configure_max_frame_size() { let _ = env_logger::try_init(); let (io, mut srv) = mock::new(); - let h2 = client::Builder::new() - .max_frame_size(16_384 * 2) - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|(mut client, h2)| { - let req = client - .get("https://example.com/") - .expect("response") - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - let body = resp.into_parts().1; - body.concat2().expect("body") - }) - .and_then(|buf| { - assert_eq!(buf.len(), 16_385); - Ok(()) - }); - - h2.expect("client").join(req) - }); - + let h2 = async move { + let (mut client, h2) = client::Builder::new() + .max_frame_size(16_384 * 2) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let req = async move { + let resp = client.get("https://example.com/").await.expect("response"); + assert_eq!(resp.status(), StatusCode::OK); + let body = resp.into_parts().1; + let buf = body.try_concat().await.expect("body"); + assert_eq!(buf.len(), 16_385); + }; + + join(async move { h2.await.expect("client") }, req).await; + }; // a good peer srv.codec_mut().set_max_send_frame_size(16_384 * 2); - let srv = srv.assert_client_handshake() - .unwrap() - .ignore_settings() - .recv_frame( + let srv = async move { + let _ = srv.assert_client_handshake().await; + srv.recv_frame( frames::headers(1) .request("GET", "https://example.com/") .eos(), ) - .send_frame(frames::headers(1).response(200)) - .send_frame(frames::data(1, vec![0; 16_385]).eos()) - .close(); - - let _ = h2.join(srv).wait().expect("wait"); + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.send_frame(frames::data(1, vec![0; 16_385]).eos()).await; + }; + join(srv, h2).await; } -#[test] -fn recv_goaway_finishes_processed_streams() { +#[tokio::test] +async fn recv_goaway_finishes_processed_streams() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://example.com/") .eos(), ) - .recv_frame( + .await; + srv.recv_frame( frames::headers(3) .request("GET", "https://example.com/") .eos(), ) - .send_frame(frames::go_away(1)) - .send_frame(frames::headers(1).response(200)) - .send_frame(frames::data(1, vec![0; 16_384]).eos()) + .await; + srv.send_frame(frames::go_away(1)).await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.send_frame(frames::data(1, vec![0; 16_384]).eos()).await; // expecting a goaway of 0, since server never initiated a stream - .recv_frame(frames::go_away(0)); + srv.recv_frame(frames::go_away(0)).await; //.close(); + }; - let h2 = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, h2)| { - let req1 = client + let h2 = async move { + let (mut client, h2) = client::handshake(io).await.expect("handshake"); + let mut client_clone = client.clone(); + let req1 = async move { + let resp = client_clone .get("https://example.com") - .expect("response") - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - let body = resp.into_parts().1; - body.concat2().expect("body") - }) - .and_then(|buf| { - assert_eq!(buf.len(), 16_384); - Ok(()) - }); - - // this request will trigger a goaway - let req2 = client - .get("https://example.com/") - .then(|res| { - let err = res.unwrap_err(); - assert_eq!(err.to_string(), "protocol error: not a result of an error"); - Ok::<(), ()>(()) - }); - - h2.expect("client").join3(req1, req2) - }); - - - h2.join(srv).wait().expect("wait"); + .await + .expect("response"); + assert_eq!(resp.status(), StatusCode::OK); + let body = resp.into_parts().1; + let buf = body.try_concat().await.expect("body"); + assert_eq!(buf.len(), 16_384); + }; + + // this request will trigger a goaway + let req2 = async move { + let err = client.get("https://example.com/").await.unwrap_err(); + assert_eq!(err.to_string(), "protocol error: not a result of an error"); + }; + + join3(async move { h2.await.expect("client") }, req1, req2).await; + }; + + join(srv, h2).await; } -#[test] -fn recv_next_stream_id_updated_by_malformed_headers() { +#[tokio::test] +async fn recv_next_stream_id_updated_by_malformed_headers() { let _ = env_logger::try_init(); - let (io, client) = mock::new(); - + let (io, mut client) = mock::new(); let bad_auth = util::byte_str("not:a/good authority"); let mut bad_headers: frame::Headers = frames::headers(1) @@ -354,317 +342,318 @@ fn recv_next_stream_id_updated_by_malformed_headers() { .into(); bad_headers.pseudo_mut().authority = Some(bad_auth); - let client = client - .assert_server_handshake() - .unwrap() - .recv_settings() + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); // bad headers -- should error. - .send_frame(bad_headers) - .recv_frame(frames::reset(1).protocol_error()) + client.send_frame(bad_headers).await; + client.recv_frame(frames::reset(1).protocol_error()).await; // this frame is good, but the stream id should already have been incr'd - .send_frame(frames::headers(1) - .request("GET", "https://example.com/") - .eos()) - .recv_frame(frames::go_away(1).protocol_error()) - .close(); - - let srv = server::handshake(io) - .expect("handshake") - .and_then(|srv| srv.into_future().then(|res| { - let (err, _) = res.unwrap_err(); - assert_eq!(err.reason(), Some(h2::Reason::PROTOCOL_ERROR)); - Ok::<(), ()>(()) - })); - - srv.join(client).wait().expect("wait"); + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + client.recv_frame(frames::go_away(1).protocol_error()).await; + }; + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + let res = srv.next().await.unwrap(); + let err = res.unwrap_err(); + assert_eq!(err.reason(), Some(h2::Reason::PROTOCOL_ERROR)); + }; + + join(srv, client).await; } -#[test] -fn skipped_stream_ids_are_implicitly_closed() { +#[tokio::test] +async fn skipped_stream_ids_are_implicitly_closed() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); - - let srv = srv - .assert_client_handshake() - .expect("handshake") - .recv_settings() - .recv_frame(frames::headers(5) - .request("GET", "https://example.com/") - .eos(), + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( + frames::headers(5) + .request("GET", "https://example.com/") + .eos(), ) + .await; // send the response on a lower-numbered stream, which should be // implicitly closed. - .send_frame(frames::headers(3).response(299)) + srv.send_frame(frames::headers(3).response(299)).await; // however, our client choose to send a RST_STREAM because it // can't tell if it had previously reset '3'. - .recv_frame(frames::reset(3).stream_closed()) - .send_frame(frames::headers(5).response(200).eos()); - - let h2 = client::Builder::new() - .initial_stream_id(5) - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|(mut client, h2)| { - let req = client - .get("https://example.com/") - .expect("response") - .map(|res| { - assert_eq!(res.status(), StatusCode::OK); - }); - h2.drive(req) - .and_then(|(conn, ())| conn.expect("client")) - }); - - h2.join(srv).wait().expect("wait"); + srv.recv_frame(frames::reset(3).stream_closed()).await; + srv.send_frame(frames::headers(5).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::Builder::new() + .initial_stream_id(5) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let req = async move { + let res = client.get("https://example.com/").await.expect("response"); + assert_eq!(res.status(), StatusCode::OK); + }; + h2.drive(req).await; + h2.await.expect("client"); + }; + + join(srv, h2).await; } -#[test] -fn send_rst_stream_allows_recv_data() { +#[tokio::test] +async fn send_rst_stream_allows_recv_data() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://example.com/") .eos(), ) - .send_frame(frames::headers(1).response(200)) - .recv_frame(frames::reset(1).cancel()) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.recv_frame(frames::reset(1).cancel()).await; // sending frames after canceled! // note: sending 2 to cosume 50% of connection window - .send_frame(frames::data(1, vec![0; 16_384])) - .send_frame(frames::data(1, vec![0; 16_384]).eos()) + srv.send_frame(frames::data(1, vec![0; 16_384])).await; + srv.send_frame(frames::data(1, vec![0; 16_384]).eos()).await; // make sure we automatically free the connection window - .recv_frame(frames::window_update(0, 16_384 * 2)) + srv.recv_frame(frames::window_update(0, 16_384 * 2)).await; // do a pingpong to ensure no other frames were sent - .ping_pong([1; 8]) - .close(); - - let client = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let req = client - .get("https://example.com/") - .expect("response") - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - // drop resp will send a reset - Ok(()) - }); - - conn.expect("client") - .drive(req) - .and_then(move |(conn, _)| conn.map(move |()| drop(client))) + srv.ping_pong([1; 8]).await; + }; + + let client = async move { + let (mut client, conn) = client::handshake(io).await.expect("handshake"); + let req = async { + let resp = client.get("https://example.com/").await.expect("response"); + assert_eq!(resp.status(), StatusCode::OK); + // drop resp will send a reset + }; + + let mut conn = Box::pin(async move { + conn.await.expect("client"); }); + conn.drive(req).await; + conn.await; + drop(client); + }; - - client.join(srv).wait().expect("wait"); + join(srv, client).await; } -#[test] -fn send_rst_stream_allows_recv_trailers() { +#[tokio::test] +async fn send_rst_stream_allows_recv_trailers() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://example.com/") .eos(), ) - .send_frame(frames::headers(1).response(200)) - .send_frame(frames::data(1, vec![0; 16_384])) - .recv_frame(frames::reset(1).cancel()) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.send_frame(frames::data(1, vec![0; 16_384])).await; + srv.recv_frame(frames::reset(1).cancel()).await; // sending frames after canceled! - .send_frame(frames::headers(1).field("foo", "bar").eos()) + srv.send_frame(frames::headers(1).field("foo", "bar").eos()) + .await; // do a pingpong to ensure no other frames were sent - .ping_pong([1; 8]) - .close(); - - let client = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let req = client - .get("https://example.com/") - .expect("response") - .map(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - // drop resp will send a reset - }); - - conn.expect("client") - .drive(req) - .and_then(move |(conn, _)| conn.map(move |()| drop(client))) - }); - - - client.join(srv).wait().expect("wait"); + srv.ping_pong([1; 8]).await; + }; + + let client = async move { + let (mut client, conn) = client::handshake(io).await.expect("handshake"); + let req = async { + let resp = client.get("https://example.com/").await.expect("response"); + assert_eq!(resp.status(), StatusCode::OK); + // drop resp will send a reset + }; + + let mut conn = Box::pin(async move { conn.await.expect("client") }); + conn.drive(req).await; + conn.await; + drop(client); + }; + + join(srv, client).await; } -#[test] -fn rst_stream_expires() { +#[tokio::test] +async fn rst_stream_expires() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://example.com/") .eos(), ) - .send_frame(frames::headers(1).response(200)) - .send_frame(frames::data(1, vec![0; 16_384])) - .recv_frame(frames::reset(1).cancel()) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.send_frame(frames::data(1, vec![0; 16_384])).await; + srv.recv_frame(frames::reset(1).cancel()).await; // wait till after the configured duration - .idle_ms(15) - .ping_pong([1; 8]) + idle_ms(15).await; + srv.ping_pong([1; 8]).await; // sending frame after canceled! - .send_frame(frames::data(1, vec![0; 16_384]).eos()) + srv.send_frame(frames::data(1, vec![0; 16_384]).eos()).await; // window capacity is returned - .recv_frame(frames::window_update(0, 16_384 * 2)) + srv.recv_frame(frames::window_update(0, 16_384 * 2)).await; // and then stream error - .recv_frame(frames::reset(1).stream_closed()) - .close(); - - let client = client::Builder::new() - .reset_stream_duration(Duration::from_millis(10)) - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let req = client - .get("https://example.com/") - .expect("response") - .map(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - // drop resp will send a reset - }); - - // no connection error should happen - conn.expect("client") - .drive(req) - .and_then(move |(conn, _)| conn.map(move |()| drop(client))) - }); - - client.join(srv).wait().expect("wait"); + srv.recv_frame(frames::reset(1).stream_closed()).await; + }; + + let client = async move { + let (mut client, conn) = client::Builder::new() + .reset_stream_duration(Duration::from_millis(10)) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let req = async { + let resp = client.get("https://example.com/").await.expect("response"); + assert_eq!(resp.status(), StatusCode::OK); + // drop resp will send a reset + }; + + // no connection error should happen + let mut conn = Box::pin(async move { conn.await.expect("client") }); + conn.drive(req).await; + conn.await; + drop(client); + }; + + join(srv, client).await; } -#[test] -fn rst_stream_max() { +#[tokio::test] +async fn rst_stream_max() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://example.com/") .eos(), ) - .recv_frame( + .await; + srv.recv_frame( frames::headers(3) .request("GET", "https://example.com/") .eos(), ) - .send_frame(frames::headers(1).response(200)) - .send_frame(frames::data(1, vec![0; 16])) - .send_frame(frames::headers(3).response(200)) - .send_frame(frames::data(3, vec![0; 16])) - .recv_frame(frames::reset(1).cancel()) - .recv_frame(frames::reset(3).cancel()) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.send_frame(frames::data(1, vec![0; 16])).await; + srv.send_frame(frames::headers(3).response(200)).await; + srv.send_frame(frames::data(3, vec![0; 16])).await; + srv.recv_frame(frames::reset(1).cancel()).await; + srv.recv_frame(frames::reset(3).cancel()).await; // sending frame after canceled! // newer streams trump older streams // 3 is still being ignored - .send_frame(frames::data(3, vec![0; 16]).eos()) + srv.send_frame(frames::data(3, vec![0; 16]).eos()).await; // ping pong to be sure of no goaway - .ping_pong([1; 8]) + srv.ping_pong([1; 8]).await; // 1 has been evicted, will get a reset - .send_frame(frames::data(1, vec![0; 16]).eos()) - .recv_frame(frames::reset(1).stream_closed()) - .close(); - - let client = client::Builder::new() - .max_concurrent_reset_streams(1) - .handshake::<_, Bytes>(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let req1 = client - .get("https://example.com/") - .expect("response1") - .map(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - // drop resp will send a reset - }); - - let req2 = client + srv.send_frame(frames::data(1, vec![0; 16]).eos()).await; + srv.recv_frame(frames::reset(1).stream_closed()).await; + }; + + let client = async move { + let (mut client, conn) = client::Builder::new() + .max_concurrent_reset_streams(1) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + let mut client_clone = client.clone(); + let req1 = async move { + let resp = client_clone .get("https://example.com/") - .expect("response2") - .map(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - // drop resp will send a reset - }); - - // no connection error should happen - conn.expect("client") - .drive(req1.join(req2)) - .and_then(move |(conn, _)| conn.map(move |()| drop(client))) + .await + .expect("response1"); + assert_eq!(resp.status(), StatusCode::OK); + // drop resp will send a reset + }; + + let req2 = async { + let resp = client.get("https://example.com/").await.expect("response2"); + assert_eq!(resp.status(), StatusCode::OK); + // drop resp will send a reset + }; + + // no connection error should happen + let mut conn = Box::pin(async move { + conn.await.expect("client"); }); + conn.drive(join(req1, req2)).await; + conn.await; + drop(client); + }; - - client.join(srv).wait().expect("wait"); + join(srv, client).await; } -#[test] -fn reserved_state_recv_window_update() { +#[tokio::test] +async fn reserved_state_recv_window_update() { let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(1) .request("GET", "https://example.com/") .eos(), ) - .send_frame( - frames::push_promise(1, 2) - .request("GET", "https://example.com/push") - ) + .await; + srv.send_frame(frames::push_promise(1, 2).request("GET", "https://example.com/push")) + .await; // it'd be weird to send a window update on a push promise, // since the client can't send us data, but whatever. The // point is that it's allowed, so we're testing it. - .send_frame(frames::window_update(2, 128)) - .send_frame(frames::headers(1).response(200).eos()) + srv.send_frame(frames::window_update(2, 128)).await; + srv.send_frame(frames::headers(1).response(200).eos()).await; // ping pong to ensure no goaway - .ping_pong([1; 8]) - .close(); - - let client = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let req = client - .get("https://example.com/") - .expect("response") - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - Ok(()) - }); + srv.ping_pong([1; 8]).await; + }; + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("handshake"); + let req = async move { + let resp = client.get("https://example.com/").await.expect("response"); + assert_eq!(resp.status(), StatusCode::OK); + }; - conn.drive(req) - .and_then(|(conn, _)| conn.expect("client")) - }); + conn.drive(req).await; + conn.await.expect("client"); + }; - - client.join(srv).wait().expect("wait"); + join(srv, client).await; } /* #[test] @@ -682,7 +671,7 @@ fn send_data_after_headers_eos() { .build(); let h2 = client::handshake(mock) - .wait().expect("handshake"); + .await.expect("handshake"); // Send the request let mut request = request::Head::default(); @@ -690,12 +679,12 @@ fn send_data_after_headers_eos() { request.uri = "https://http2.akamai.com/".parse().unwrap(); let id = 1.into(); - let h2 = h2.send_request(id, request, true).wait().expect("send request"); + let h2 = h2.send_request(id, request, true).await.expect("send request"); let body = "hello"; // Send the data - let err = h2.send_data(id, body.into(), true).wait().unwrap_err(); + let err = h2.send_data(id, body.into(), true).await.unwrap_err(); assert_user_err!(err, UnexpectedFrameType); } @@ -705,79 +694,68 @@ fn exceed_max_streams() { } */ - -#[test] -fn rst_while_closing() { +#[tokio::test] +async fn rst_while_closing() { // Test to reproduce panic in issue #246 --- receipt of a RST_STREAM frame // on a stream in the Half Closed (remote) state with a queued EOS causes // a panic. let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); // Rendevous when we've queued a trailers frame - let (tx, rx) = crate::futures::sync::oneshot::channel(); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://example.com/") - ) - .send_frame(frames::headers(1).response(200)) - .send_frame(frames::headers(1).eos()) + let (tx, rx) = crate::futures::channel::oneshot::channel(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.send_frame(frames::headers(1).eos()).await; // Idling for a moment here is necessary to ensure that the client // enqueues its TRAILERS frame *before* we send the RST_STREAM frame // which causes the panic. - .wait_for(rx) + rx.await.unwrap(); // Send the RST_STREAM frame which causes the client to panic. - .send_frame(frames::reset(1).cancel()) - .ping_pong([1; 8]) - .recv_frame(frames::go_away(0).no_error()) - .close(); - ; - - let client = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); - - // The request should be left streaming. - let (resp, stream) = client.send_request(request, false) - .expect("send_request"); - let req = resp - // on receipt of an EOS response from the server, transition - // the stream Open => Half Closed (remote). - .expect("response"); - conn.drive(req) - .map(move |(conn, resp)| { - assert_eq!(resp.status(), StatusCode::OK); - (conn, stream) - }) - }) - .and_then(|(conn, mut stream)| { - // Enqueue trailers frame. - let _ = stream.send_trailers(HeaderMap::new()); - // Signal the server mock to send RST_FRAME - let _ = tx.send(()); - - conn - // yield once to allow the server mock to be polled - // before the conn flushes its buffer - .yield_once() - .expect("client") - }); - - - client.join(srv).wait().expect("wait"); + srv.send_frame(frames::reset(1).cancel()).await; + srv.ping_pong([1; 8]).await; + srv.recv_frame(frames::go_away(0).no_error()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("handshake"); + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + + // The request should be left streaming. + let req = async move { + let (resp, stream) = client.send_request(request, false).expect("send_request"); + // on receipt of an EOS response from the server, transition + // the stream Open => Half Closed (remote). + let resp = resp.await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + stream + }; + let mut stream = conn.drive(req).await; + // Enqueue trailers frame. + let _ = stream.send_trailers(HeaderMap::new()); + // Signal the server mock to send RST_FRAME + let _ = tx.send(()).unwrap(); + drop(stream); + yield_once().await; + // yield once to allow the server mock to be polled + // before the conn flushes its buffer + conn.await.expect("client"); + }; + + join(srv, client).await; } -#[test] -fn rst_with_buffered_data() { +#[tokio::test] +async fn rst_with_buffered_data() { // Data is buffered in `FramedWrite` and the stream is reset locally before // the data is fully flushed. Given that resetting a stream requires // clearing all associated state for that stream, this test ensures that the @@ -785,64 +763,51 @@ fn rst_with_buffered_data() { let _ = env_logger::try_init(); // This allows the settings + headers frame through - let (io, srv) = mock::new_with_write_capacity(73); + let (io, mut srv) = mock::new_with_write_capacity(73); // Synchronize the client / server on response - let (tx, rx) = crate::futures::sync::oneshot::channel(); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://example.com/") - ) - .buffer_bytes(128) - .send_frame(frames::headers(1).response(204).eos()) - .send_frame(frames::reset(1).cancel()) - .wait_for(rx) - .unbounded_bytes() - .recv_frame( - frames::data(1, vec![0; 16_384])) - .close() - ; + let (tx, rx) = crate::futures::channel::oneshot::channel(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://example.com/")) + .await; + srv.buffer_bytes(128).await; + srv.send_frame(frames::headers(1).response(204).eos()).await; + srv.send_frame(frames::reset(1).cancel()).await; + rx.await.unwrap(); + srv.unbounded_bytes().await; + srv.recv_frame(frames::data(1, vec![0; 16_384])).await; + }; // A large body - let body = vec![0; 2 * frame::DEFAULT_INITIAL_WINDOW_SIZE as usize]; + let body = vec![0u8; 2 * frame::DEFAULT_INITIAL_WINDOW_SIZE as usize]; - let client = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, conn)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); - - // Send the request - let (resp, mut stream) = client.send_request(request, false) - .expect("send_request"); - - // Send the data - stream.send_data(body.into(), true).unwrap(); - - conn.drive({ - resp.then(|_res| { - Ok::<_, ()>(()) - }) - }) - }) - .and_then(move |(conn, _)| { - tx.send(()).unwrap(); - conn.unwrap() - }); + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("handshake"); + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + + // Send the request + let (resp, mut stream) = client.send_request(request, false).expect("send_request"); + // Send the data + stream.send_data(body.into(), true).unwrap(); - client.join(srv).wait().expect("wait"); + conn.drive(resp).await.ok(); + tx.send(()).unwrap(); + conn.await.unwrap(); + }; + + join(srv, client).await; } -#[test] -fn err_with_buffered_data() { +#[tokio::test] +async fn err_with_buffered_data() { // Data is buffered in `FramedWrite` and the stream is reset locally before // the data is fully flushed. Given that resetting a stream requires // clearing all associated state for that stream, this test ensures that the @@ -850,63 +815,54 @@ fn err_with_buffered_data() { let _ = env_logger::try_init(); // This allows the settings + headers frame through - let (io, srv) = mock::new_with_write_capacity(73); + let (io, mut srv) = mock::new_with_write_capacity(73); // Synchronize the client / server on response - let (tx, rx) = crate::futures::sync::oneshot::channel(); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://example.com/") - ) - .buffer_bytes(128) - .send_frame(frames::headers(1).response(204).eos()) + let (tx, rx) = crate::futures::channel::oneshot::channel(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://example.com/")) + .await; + srv.buffer_bytes(128).await; + srv.send_frame(frames::headers(1).response(204).eos()).await; // Send invalid data - .send_bytes(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00") - .wait_for(rx) - .unbounded_bytes() - .recv_frame( - frames::data(1, vec![0; 16_384])) - .close() - ; + srv.send_bytes(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00") + .await; + rx.await.unwrap(); + srv.unbounded_bytes().await; + srv.recv_frame(frames::data(1, vec![0; 16_384])).await; + }; // A large body let body = vec![0; 2 * frame::DEFAULT_INITIAL_WINDOW_SIZE as usize]; - let client = client::handshake(io) - .then(|res| { - let (mut client, conn) = res.unwrap(); + let client = async move { + let (mut client, conn) = client::handshake(io).await.unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); + // Send the request + let (resp, mut stream) = client.send_request(request, false).expect("send_request"); - // Send the request - let (resp, mut stream) = client.send_request(request, false) - .expect("send_request"); + // Send the data + stream.send_data(body.into(), true).unwrap(); - // Send the data - stream.send_data(body.into(), true).unwrap(); + drop(client); + let res = try_join(conn, resp).await; + assert!(res.is_err()); + tx.send(()).unwrap(); + }; - conn.join(resp) - }) - .then(move |res| { - assert!(res.is_err()); - tx.send(()).unwrap(); - Ok(()) - }); - - - client.join(srv).wait().expect("wait"); + join(srv, client).await; } -#[test] -fn send_err_with_buffered_data() { +#[tokio::test] +async fn send_err_with_buffered_data() { // Data is buffered in `FramedWrite` and the stream is reset locally before // the data is fully flushed. Given that resetting a stream requires // clearing all associated state for that stream, this test ensures that the @@ -914,123 +870,112 @@ fn send_err_with_buffered_data() { let _ = env_logger::try_init(); // This allows the settings + headers frame through - let (io, srv) = mock::new_with_write_capacity(73); + let (io, mut srv) = mock::new_with_write_capacity(73); // Synchronize the client / server on response - let (tx, rx) = crate::futures::sync::oneshot::channel(); - - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( - frames::headers(1) - .request("POST", "https://example.com/") - ) - .buffer_bytes(128) - .send_frame(frames::headers(1).response(204).eos()) - .wait_for(rx) - .unbounded_bytes() - .recv_frame( - frames::data(1, vec![0; 16_384])) - .recv_frame(frames::reset(1).cancel()) - .recv_frame(frames::go_away(0).no_error()) - .close() - ; + let (tx, rx) = crate::futures::channel::oneshot::channel(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://example.com/")) + .await; + srv.buffer_bytes(128).await; + srv.send_frame(frames::headers(1).response(204).eos()).await; + rx.await.unwrap(); + srv.unbounded_bytes().await; + srv.recv_frame(frames::data(1, vec![0; 16_384])).await; + srv.recv_frame(frames::reset(1).cancel()).await; + srv.recv_frame(frames::go_away(0).no_error()).await; + }; // A large body let body = vec![0; 2 * frame::DEFAULT_INITIAL_WINDOW_SIZE as usize]; - let client = client::handshake(io) - .expect("handshake") - .and_then(|(mut client, mut conn)| { - let request = Request::builder() - .method(Method::POST) - .uri("https://example.com/") - .body(()) - .unwrap(); - - // Send the request - let (resp, mut stream) = client.send_request(request, false) - .expect("send_request"); - - // Send the data - stream.send_data(body.into(), true).unwrap(); - - // Hack to drive the connection, trying to flush data - crate::futures::future::lazy(|| { - conn.poll().unwrap(); - Ok::<_, ()>(()) - }).wait().unwrap(); - - // Send a reset - stream.send_reset(Reason::CANCEL); - - conn.drive({ - resp.then(|_res| { - Ok::<_, ()>(()) - }) - }) + let client = async move { + let (mut client, mut conn) = client::handshake(io).await.expect("handshake"); + let request = Request::builder() + .method(Method::POST) + .uri("https://example.com/") + .body(()) + .unwrap(); + + // Send the request + let (resp, mut stream) = client.send_request(request, false).expect("send_request"); + + // Send the data + stream.send_data(body.into(), true).unwrap(); + + // Hack to drive the connection, trying to flush data + lazy(|cx| { + if let Poll::Ready(v) = conn.poll_unpin(cx) { + v.unwrap(); + } }) - .and_then(move |(conn, _)| { - tx.send(()).unwrap(); - conn.unwrap() - }); - - - client.join(srv).wait().expect("wait"); + .await; + + // Send a reset + stream.send_reset(Reason::CANCEL); + drop(stream); + drop(client); + conn.drive(resp).await.ok(); + tx.send(()).unwrap(); + conn.await.unwrap(); + }; + + join(srv, client).await; } -#[test] -fn srv_window_update_on_lower_stream_id() { +#[tokio::test] +async fn srv_window_update_on_lower_stream_id() { // See https://github.com/hyperium/h2/issues/208 let _ = env_logger::try_init(); - let (io, srv) = mock::new(); + let (io, mut srv) = mock::new(); - let srv = srv.assert_client_handshake() - .unwrap() - .recv_settings() - .recv_frame( + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame( frames::headers(7) .request("GET", "https://example.com/") - .eos() + .eos(), + ) + .await; + srv.send_frame( + frames::push_promise(7, 2).request("GET", "https://http2.akamai.com/style.css"), ) - .send_frame(frames::push_promise(7, 2).request("GET", "https://http2.akamai.com/style.css")) - .send_frame(frames::headers(7).eos()) - .recv_frame(frames::reset(2).cancel()) - .send_frame(frames::window_update(5, 66666)) - .close() - ; - - let client = client::Builder::new() - .initial_stream_id(7) - .handshake::<_, Bytes>(io) - .unwrap() - .and_then(|(mut client, h2)| { - let request = Request::builder() - .method("GET") - .uri("https://example.com/") - .body(()).unwrap(); - - let response = client.send_request(request, true) + .await; + srv.send_frame(frames::headers(7).eos()).await; + srv.recv_frame(frames::reset(2).cancel()).await; + srv.send_frame(frames::window_update(5, 66666)).await; + }; + + let client = async move { + let (mut client, mut h2) = client::Builder::new() + .initial_stream_id(7) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method("GET") + .uri("https://example.com/") + .body(()) + .unwrap(); + + let response = async { + let resp = client + .send_request(request, true) .unwrap() - .0.expect("response") - .and_then(|resp| { - assert_eq!(resp.status(), StatusCode::OK); - Ok(()) - }); - - h2.expect("client") - .drive(response) - .and_then(move |(h2, _)| { - println!("RESPONSE DONE"); - h2.map(move |()| drop(client)) - }) - .then(|result| { - println!("WUT"); - assert!(result.is_ok(), "result: {:?}", result); - Ok::<_, ()>(()) - }) - }); - - srv.join(client).wait().unwrap(); + .0 + .await + .expect("response"); + assert_eq!(resp.status(), StatusCode::OK); + }; + + h2.drive(response).await; + println!("RESPONSE DONE"); + h2.await.expect("client"); + drop(client); + }; + join(srv, client).await; } diff --git a/tests/h2-tests/tests/trailers.rs b/tests/h2-tests/tests/trailers.rs index 64c77f29a..47935342f 100644 --- a/tests/h2-tests/tests/trailers.rs +++ b/tests/h2-tests/tests/trailers.rs @@ -1,25 +1,28 @@ +#![feature(async_await)] + +use futures::StreamExt; use h2_support::prelude::*; -#[test] -fn recv_trailers_only() { +#[tokio::test] +async fn recv_trailers_only() { let _ = env_logger::try_init(); let mock = mock_io::Builder::new() .handshake() // Write GET / .write(&[ - 0, 0, 0x10, 1, 5, 0, 0, 0, 1, 0x82, 0x87, 0x41, 0x8B, 0x9D, 0x29, - 0xAC, 0x4B, 0x8F, 0xA8, 0xE9, 0x19, 0x97, 0x21, 0xE9, 0x84, + 0, 0, 0x10, 1, 5, 0, 0, 0, 1, 0x82, 0x87, 0x41, 0x8B, 0x9D, 0x29, 0xAC, 0x4B, 0x8F, + 0xA8, 0xE9, 0x19, 0x97, 0x21, 0xE9, 0x84, ]) .write(frames::SETTINGS_ACK) // Read response .read(&[ - 0, 0, 1, 1, 4, 0, 0, 0, 1, 0x88, 0, 0, 9, 1, 5, 0, 0, 0, 1, - 0x40, 0x84, 0x42, 0x46, 0x9B, 0x51, 0x82, 0x3F, 0x5F, + 0, 0, 1, 1, 4, 0, 0, 0, 1, 0x88, 0, 0, 9, 1, 5, 0, 0, 0, 1, 0x40, 0x84, 0x42, 0x46, + 0x9B, 0x51, 0x82, 0x3F, 0x5F, ]) .build(); - let (mut client, mut h2) = client::handshake(mock).wait().unwrap(); + let (mut client, mut h2) = client::handshake(mock).await.unwrap(); // Send the request let request = Request::builder() @@ -30,44 +33,47 @@ fn recv_trailers_only() { log::info!("sending request"); let (response, _) = client.send_request(request, true).unwrap(); - let response = h2.run(response).unwrap(); + let response = h2.run(response).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); let (_, mut body) = response.into_parts(); // Make sure there is no body - let chunk = h2.run(poll_fn(|| body.poll())).unwrap(); + let chunk = h2.run(Box::pin(body.next())).await; assert!(chunk.is_none()); - let trailers = h2.run(poll_fn(|| body.poll_trailers())).unwrap().unwrap(); + let trailers = h2 + .run(poll_fn(|cx| body.poll_trailers(cx))) + .await + .unwrap() + .unwrap(); assert_eq!(1, trailers.len()); assert_eq!(trailers["status"], "ok"); - h2.wait().unwrap(); + h2.await.unwrap(); } -#[test] -fn send_trailers_immediately() { +#[tokio::test] +async fn send_trailers_immediately() { let _ = env_logger::try_init(); let mock = mock_io::Builder::new() .handshake() // Write GET / .write(&[ - 0, 0, 0x10, 1, 4, 0, 0, 0, 1, 0x82, 0x87, 0x41, 0x8B, 0x9D, 0x29, - 0xAC, 0x4B, 0x8F, 0xA8, 0xE9, 0x19, 0x97, 0x21, 0xE9, 0x84, 0, 0, - 0x0A, 1, 5, 0, 0, 0, 1, 0x40, 0x83, 0xF6, 0x7A, 0x66, 0x84, 0x9C, - 0xB4, 0x50, 0x7F, + 0, 0, 0x10, 1, 4, 0, 0, 0, 1, 0x82, 0x87, 0x41, 0x8B, 0x9D, 0x29, 0xAC, 0x4B, 0x8F, + 0xA8, 0xE9, 0x19, 0x97, 0x21, 0xE9, 0x84, 0, 0, 0x0A, 1, 5, 0, 0, 0, 1, 0x40, 0x83, + 0xF6, 0x7A, 0x66, 0x84, 0x9C, 0xB4, 0x50, 0x7F, ]) .write(frames::SETTINGS_ACK) // Read response .read(&[ - 0, 0, 1, 1, 4, 0, 0, 0, 1, 0x88, 0, 0, 0x0B, 0, 1, 0, 0, 0, 1, - 0x68, 0x65, 0x6C, 0x6C, 0x6F, 0x20, 0x77, 0x6F, 0x72, 0x6C, 0x64, + 0, 0, 1, 1, 4, 0, 0, 0, 1, 0x88, 0, 0, 0x0B, 0, 1, 0, 0, 0, 1, 0x68, 0x65, 0x6C, 0x6C, + 0x6F, 0x20, 0x77, 0x6F, 0x72, 0x6C, 0x64, ]) .build(); - let (mut client, mut h2) = client::handshake(mock).wait().unwrap(); + let (mut client, mut h2) = client::handshake(mock).await.unwrap(); // Send the request let request = Request::builder() @@ -83,22 +89,21 @@ fn send_trailers_immediately() { stream.send_trailers(trailers).unwrap(); - let response = h2.run(response).unwrap(); + let response = h2.run(response).await.unwrap(); assert_eq!(response.status(), StatusCode::OK); let (_, mut body) = response.into_parts(); // There is a data chunk - let chunk = h2.run(poll_fn(|| body.poll())).unwrap(); - assert!(chunk.is_some()); + let _ = h2.run(body.next()).await.unwrap().unwrap(); - let chunk = h2.run(poll_fn(|| body.poll())).unwrap(); + let chunk = h2.run(body.next()).await; assert!(chunk.is_none()); - let trailers = h2.run(poll_fn(|| body.poll_trailers())).unwrap(); + let trailers = h2.run(poll_fn(|cx| body.poll_trailers(cx))).await; assert!(trailers.is_none()); - h2.wait().unwrap(); + h2.await.unwrap(); } #[test]