diff --git a/Cargo.toml b/Cargo.toml index 34ca3d687..e75849105 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,7 +43,7 @@ atomic-waker = "1.0.0" futures-core = { version = "0.3", default-features = false } futures-sink = { version = "0.3", default-features = false } tokio-util = { version = "0.7.1", features = ["codec", "io"] } -tokio = { version = "1", features = ["io-util"] } +tokio = { version = "1", features = ["io-util", "time"] } bytes = "1" http = "1" tracing = { version = "0.1.35", default-features = false, features = ["std"] } diff --git a/src/client.rs b/src/client.rs index ffeda6077..f6ea04dd8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -343,6 +343,8 @@ pub struct Builder { /// /// When this gets exceeded, we issue GOAWAYs. local_max_error_reset_streams: Option<usize>, + + keepalive_timeout: Option<Duration>, } #[derive(Debug)] @@ -580,7 +582,6 @@ where } } -#[cfg(feature = "unstable")] impl<B> SendRequest<B> where B: Buf, @@ -661,6 +662,7 @@ initial_target_connection_window_size: None, initial_max_send_streams: usize::MAX, settings: Default::default(), + keepalive_timeout: None, stream_id: 1.into(), local_max_error_reset_streams: Some(proto::DEFAULT_LOCAL_RESET_COUNT_MAX), } @@ -996,6 +998,11 @@ self } + /// Sets the duration after which the connection is closed when there are no active streams. + pub fn keepalive_timeout(&mut self, dur: Duration) -> &mut Self { + self.keepalive_timeout = Some(dur); + self + } /// Sets the maximum number of local resets due to protocol errors made by the remote end. /// /// Invalid frames and many other protocol errors will lead to resets being generated for those streams. 
@@ -1332,6 +1339,7 @@ where max_send_buffer_size: builder.max_send_buffer_size, reset_stream_duration: builder.reset_stream_duration, reset_stream_max: builder.reset_stream_max, + keepalive_timeout: builder.keepalive_timeout, remote_reset_stream_max: builder.pending_accept_reset_stream_max, local_error_reset_streams_max: builder.local_max_error_reset_streams, settings: builder.settings.clone(), diff --git a/src/frame/reason.rs b/src/frame/reason.rs index ff5e2012f..3b3660610 100644 --- a/src/frame/reason.rs +++ b/src/frame/reason.rs @@ -58,6 +58,8 @@ impl Reason { pub const INADEQUATE_SECURITY: Reason = Reason(12); /// The endpoint requires that HTTP/1.1 be used instead of HTTP/2. pub const HTTP_1_1_REQUIRED: Reason = Reason(13); + /// The endpoint reached the keepalive timeout. + pub const KEEPALIVE_TIMEOUT: Reason = Reason(14); /// Get a string description of the error code. pub fn description(&self) -> &str { @@ -79,6 +81,7 @@ impl Reason { 11 => "detected excessive load generating behavior", 12 => "security properties do not meet minimum requirements", 13 => "endpoint requires HTTP/1.1", + 14 => "keepalive timeout reached", _ => "unknown reason", } } diff --git a/src/proto/connection.rs b/src/proto/connection.rs index 5589fabcb..46dbc3167 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -7,12 +7,14 @@ use crate::proto::*; use bytes::Bytes; use futures_core::Stream; +use std::future::Future; use std::io; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::AsyncRead; +use tokio::time::Sleep; /// An H2 connection #[derive(Debug)] @@ -57,6 +59,9 @@ where /// A `tracing` span tracking the lifetime of the connection. span: tracing::Span, + keepalive: Option<Pin<Box<Sleep>>>, + keepalive_timeout: Option<Duration>, + /// Client or server _phantom: PhantomData<P>, } @@ -82,6 +87,7 @@ pub(crate) struct Config { pub reset_stream_max: usize, pub remote_reset_stream_max: usize, pub local_error_reset_streams_max: Option<usize>, + pub keepalive_timeout: Option<Duration>, pub settings: frame::Settings, } @@ -135,6 +141,8 @@ where ping_pong: PingPong::new(), settings: Settings::new(config.settings), streams, + keepalive: None, + keepalive_timeout: config.keepalive_timeout, span: tracing::debug_span!("Connection", peer = %P::NAME), _phantom: PhantomData, }, @@ -173,6 +181,10 @@ where pub(crate) fn max_recv_streams(&self) -> usize { self.inner.streams.max_recv_streams() } + /// Returns the number of active streams. + pub(crate) fn active_streams(&self) -> usize { + self.inner.streams.num_active_streams() + } #[cfg(feature = "unstable")] pub fn num_wired_streams(&self) -> usize { @@ -263,22 +275,23 @@ where let _e = span.enter(); let span = tracing::trace_span!("poll"); let _e = span.enter(); - - loop { + 'outer: loop { tracing::trace!(connection.state = ?self.inner.state); // TODO: probably clean up this glob of code match self.inner.state { // When open, continue to poll a frame State::Open => { let result = match self.poll2(cx) { - Poll::Ready(result) => result, + Poll::Ready(result) => { + self.inner.keepalive = None; + result + } // The connection is not ready to make progress Poll::Pending => { // Ensure all window updates have been sent. 
// // This will also handle flushing `self.codec` ready!(self.inner.streams.poll_complete(cx, &mut self.codec))?; - if (self.inner.error.is_some() || self.inner.go_away.should_close_on_idle()) && !self.inner.streams.has_streams() @@ -286,7 +299,28 @@ where self.inner.as_dyn().go_away_now(Reason::NO_ERROR); continue; } - + if !self.inner.streams.has_streams() { + loop { + match ( + self.inner.keepalive.as_mut(), + self.inner.keepalive_timeout, + ) { + (Some(sleep), _) => { + ready!(sleep.as_mut().poll(cx)); + self.inner + .as_dyn() + .go_away_now(Reason::KEEPALIVE_TIMEOUT); + continue 'outer; + } + (None, Some(timeout)) => { + self.inner + .keepalive + .replace(Box::pin(tokio::time::sleep(timeout))); + } + _ => break, + } + } + } return Poll::Pending; } }; diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index 875b6103f..b2f7918f7 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -1008,7 +1008,6 @@ where self.inner.lock().unwrap().counts.max_recv_streams() } - #[cfg(feature = "unstable")] pub fn num_active_streams(&self) -> usize { let me = self.inner.lock().unwrap(); me.store.num_active_streams() diff --git a/src/server.rs b/src/server.rs index b00bc0866..1e74d3f88 100644 --- a/src/server.rs +++ b/src/server.rs @@ -258,6 +258,9 @@ pub struct Builder { /// /// When this gets exceeded, we issue GOAWAYs. local_max_error_reset_streams: Option<usize>, + + /// Keepalive timeout + keepalive_timeout: Option<Duration>, } /// Send a response back to the client @@ -581,6 +584,15 @@ where self.connection.max_recv_streams() } + /// Returns whether the connection has any live streams or other references to them. + pub fn has_streams_or_other_references(&self) -> bool { + self.connection.has_streams_or_other_references() + } + /// Returns the number of currently active streams. + pub fn active_stream(&self) -> usize { + self.connection.active_streams() + } + // Could disappear at anytime. 
#[doc(hidden)] #[cfg(feature = "unstable")] @@ -650,7 +662,7 @@ impl Builder { settings: Settings::default(), initial_target_connection_window_size: None, max_send_buffer_size: proto::DEFAULT_MAX_SEND_BUFFER_SIZE, - + keepalive_timeout: None, local_max_error_reset_streams: Some(proto::DEFAULT_LOCAL_RESET_COUNT_MAX), } } @@ -1015,6 +1027,11 @@ self } + /// Sets the duration after which the connection is closed when there are no active streams. + pub fn keepalive_timeout(&mut self, dur: Duration) -> &mut Self { + self.keepalive_timeout = Some(dur); + self + } /// Enables the [extended CONNECT protocol]. /// /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 @@ -1379,6 +1396,7 @@ where initial_max_send_streams: 0, max_send_buffer_size: self.builder.max_send_buffer_size, reset_stream_duration: self.builder.reset_stream_duration, + keepalive_timeout: self.builder.keepalive_timeout, reset_stream_max: self.builder.reset_stream_max, remote_reset_stream_max: self.builder.pending_accept_reset_stream_max, local_error_reset_streams_max: self diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index c1af54198..fa5c74b15 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -69,6 +69,41 @@ async fn server_builder_set_keepalive_timeout() { join(client, h2).await; } +#[tokio::test] +async fn server_builder_set_keepalive_timeout() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + let h1 = async { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; + }; + + let mut builder = server::Builder::new(); + builder.keepalive_timeout(Duration::from_secs(2)); + let h2 = async move { + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); 
+ let (req, mut stream) = srv.next().await.unwrap().unwrap(); + assert_eq!(req.method(), &http::Method::GET); + + let rsp = http::Response::builder().status(200).body(()).unwrap(); + let res = stream.send_response(rsp, true).unwrap(); + drop(res); + let r1 = srv.accept().await; + println!("rrr {r1:?}"); + assert!(r1.is_some_and(|f| f.is_err_and(|f| f.is_go_away()))); + }; + join(h1, h2).await; +} #[tokio::test] async fn serve_request() { h2_support::trace_init!();