Skip to content

feat: Add keepalive timeout when connection idle #827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ atomic-waker = "1.0.0"
futures-core = { version = "0.3", default-features = false }
futures-sink = { version = "0.3", default-features = false }
tokio-util = { version = "0.7.1", features = ["codec", "io"] }
tokio = { version = "1", features = ["io-util"] }
tokio = { version = "1", features = ["io-util", "time"] }
bytes = "1"
http = "1"
tracing = { version = "0.1.35", default-features = false, features = ["std"] }
10 changes: 9 additions & 1 deletion src/client.rs
Original file line number Diff line number Diff line change
@@ -343,6 +343,8 @@ pub struct Builder {
///
/// When this gets exceeded, we issue GOAWAYs.
local_max_error_reset_streams: Option<usize>,

keepalive_timeout: Option<Duration>,
}

#[derive(Debug)]
@@ -580,7 +582,6 @@ where
}
}

#[cfg(feature = "unstable")]
impl<B> SendRequest<B>
where
B: Buf,
@@ -661,6 +662,7 @@ impl Builder {
initial_target_connection_window_size: None,
initial_max_send_streams: usize::MAX,
settings: Default::default(),
keepalive_timeout: None,
stream_id: 1.into(),
local_max_error_reset_streams: Some(proto::DEFAULT_LOCAL_RESET_COUNT_MAX),
}
@@ -996,6 +998,11 @@ impl Builder {
self
}

/// Sets the keepalive timeout: how long the connection may remain idle
/// (with no active streams) before it is closed with a GOAWAY.
///
/// When unset, idle connections are kept open indefinitely.
pub fn keepalive_timeout(&mut self, dur: Duration) -> &mut Self {
self.keepalive_timeout = Some(dur);
self
}
/// Sets the maximum number of local resets due to protocol errors made by the remote end.
///
/// Invalid frames and many other protocol errors will lead to resets being generated for those streams.
@@ -1332,6 +1339,7 @@ where
max_send_buffer_size: builder.max_send_buffer_size,
reset_stream_duration: builder.reset_stream_duration,
reset_stream_max: builder.reset_stream_max,
keepalive_timeout: builder.keepalive_timeout,
remote_reset_stream_max: builder.pending_accept_reset_stream_max,
local_error_reset_streams_max: builder.local_max_error_reset_streams,
settings: builder.settings.clone(),
3 changes: 3 additions & 0 deletions src/frame/reason.rs
Original file line number Diff line number Diff line change
@@ -58,6 +58,8 @@ impl Reason {
pub const INADEQUATE_SECURITY: Reason = Reason(12);
/// The endpoint requires that HTTP/1.1 be used instead of HTTP/2.
pub const HTTP_1_1_REQUIRED: Reason = Reason(13);
/// The endpoint reached its keepalive timeout while the connection was idle.
pub const KEEPALIVE_TIMEOUT: Reason = Reason(14);

/// Get a string description of the error code.
pub fn description(&self) -> &str {
@@ -79,6 +81,7 @@ impl Reason {
11 => "detected excessive load generating behavior",
12 => "security properties do not meet minimum requirements",
13 => "endpoint requires HTTP/1.1",
14 => "keepalive timeout reached",
_ => "unknown reason",
}
}
44 changes: 39 additions & 5 deletions src/proto/connection.rs
Original file line number Diff line number Diff line change
@@ -7,12 +7,14 @@ use crate::proto::*;

use bytes::Bytes;
use futures_core::Stream;
use std::future::Future;
use std::io;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::AsyncRead;
use tokio::time::Sleep;

/// An H2 connection
#[derive(Debug)]
@@ -57,6 +59,9 @@ where
/// A `tracing` span tracking the lifetime of the connection.
span: tracing::Span,

keepalive: Option<Pin<Box<Sleep>>>,
keepalive_timeout: Option<Duration>,

/// Client or server
_phantom: PhantomData<P>,
}
@@ -82,6 +87,7 @@ pub(crate) struct Config {
pub reset_stream_max: usize,
pub remote_reset_stream_max: usize,
pub local_error_reset_streams_max: Option<usize>,
pub keepalive_timeout: Option<Duration>,
pub settings: frame::Settings,
}

@@ -135,6 +141,8 @@ where
ping_pong: PingPong::new(),
settings: Settings::new(config.settings),
streams,
keepalive: None,
keepalive_timeout: config.keepalive_timeout,
span: tracing::debug_span!("Connection", peer = %P::NAME),
_phantom: PhantomData,
},
@@ -173,6 +181,10 @@ where
pub(crate) fn max_recv_streams(&self) -> usize {
self.inner.streams.max_recv_streams()
}
/// Returns the number of currently active streams on this connection.
pub(crate) fn active_streams(&self) -> usize {
self.inner.streams.num_active_streams()
}

#[cfg(feature = "unstable")]
pub fn num_wired_streams(&self) -> usize {
@@ -263,30 +275,52 @@ where
let _e = span.enter();
let span = tracing::trace_span!("poll");
let _e = span.enter();

loop {
'outer: loop {
tracing::trace!(connection.state = ?self.inner.state);
// TODO: probably clean up this glob of code
match self.inner.state {
// When open, continue to poll a frame
State::Open => {
let result = match self.poll2(cx) {
Poll::Ready(result) => result,
Poll::Ready(result) => {
self.inner.keepalive = None;
result
}
// The connection is not ready to make progress
Poll::Pending => {
// Ensure all window updates have been sent.
//
// This will also handle flushing `self.codec`
ready!(self.inner.streams.poll_complete(cx, &mut self.codec))?;

if (self.inner.error.is_some()
|| self.inner.go_away.should_close_on_idle())
&& !self.inner.streams.has_streams()
{
self.inner.as_dyn().go_away_now(Reason::NO_ERROR);
continue;
}

if !self.inner.streams.has_streams() {
loop {
match (
self.inner.keepalive.as_mut(),
self.inner.keepalive_timeout,
) {
(Some(sleep), _) => {
ready!(sleep.as_mut().poll(cx));
self.inner
.as_dyn()
.go_away_now(Reason::KEEPALIVE_TIMEOUT);
continue 'outer;
}
(None, Some(timeout)) => {
self.inner
.keepalive
.replace(Box::pin(tokio::time::sleep(timeout)));
}
_ => break,
}
}
}
return Poll::Pending;
}
};
1 change: 0 additions & 1 deletion src/proto/streams/streams.rs
Original file line number Diff line number Diff line change
@@ -1008,7 +1008,6 @@ where
self.inner.lock().unwrap().counts.max_recv_streams()
}

#[cfg(feature = "unstable")]
pub fn num_active_streams(&self) -> usize {
let me = self.inner.lock().unwrap();
me.store.num_active_streams()
20 changes: 19 additions & 1 deletion src/server.rs
Original file line number Diff line number Diff line change
@@ -258,6 +258,9 @@ pub struct Builder {
///
/// When this gets exceeded, we issue GOAWAYs.
local_max_error_reset_streams: Option<usize>,

/// Keepalive timeout
keepalive_timeout: Option<Duration>,
}

/// Send a response back to the client
@@ -581,6 +584,15 @@ where
self.connection.max_recv_streams()
}

/// Returns whether the connection still has live streams or other
/// outstanding references keeping it busy.
pub fn has_streams_or_other_references(&self) -> bool {
self.connection.has_streams_or_other_references()
}
/// Returns the number of currently active streams.
pub fn active_stream(&self) -> usize {
self.connection.active_streams()
}

// Could disappear at anytime.
#[doc(hidden)]
#[cfg(feature = "unstable")]
@@ -650,7 +662,7 @@ impl Builder {
settings: Settings::default(),
initial_target_connection_window_size: None,
max_send_buffer_size: proto::DEFAULT_MAX_SEND_BUFFER_SIZE,

keepalive_timeout: None,
local_max_error_reset_streams: Some(proto::DEFAULT_LOCAL_RESET_COUNT_MAX),
}
}
@@ -1015,6 +1027,11 @@ impl Builder {
self
}

/// Sets the keepalive timeout: how long the connection may remain idle
/// (with no active streams) before it is closed with a GOAWAY.
///
/// When unset, idle connections are kept open indefinitely.
pub fn keepalive_timeout(&mut self, dur: Duration) -> &mut Self {
self.keepalive_timeout = Some(dur);
self
}
/// Enables the [extended CONNECT protocol].
///
/// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
@@ -1379,6 +1396,7 @@ where
initial_max_send_streams: 0,
max_send_buffer_size: self.builder.max_send_buffer_size,
reset_stream_duration: self.builder.reset_stream_duration,
keepalive_timeout: self.builder.keepalive_timeout,
reset_stream_max: self.builder.reset_stream_max,
remote_reset_stream_max: self.builder.pending_accept_reset_stream_max,
local_error_reset_streams_max: self
35 changes: 35 additions & 0 deletions tests/h2-tests/tests/server.rs
Original file line number Diff line number Diff line change
@@ -69,6 +69,41 @@ async fn server_builder_set_max_concurrent_streams() {
join(client, h2).await;
}

#[tokio::test]
async fn server_builder_set_keepalive_timeout() {
    h2_support::trace_init!();
    let (io, mut client) = mock::new();

    // Client side: complete one request/response exchange, then go idle so
    // the server's keepalive timer can fire.
    let client = async {
        let settings = client.assert_server_handshake().await;
        assert_default_settings!(settings);
        client
            .send_frame(
                frames::headers(1)
                    .request("GET", "https://example.com/")
                    .eos(),
            )
            .await;
        client
            .recv_frame(frames::headers(1).response(200).eos())
            .await;
    };

    let mut builder = server::Builder::new();
    builder.keepalive_timeout(Duration::from_secs(2));

    // Server side: serve the single request, then expect the now-idle
    // connection to be torn down with a GOAWAY once the timeout elapses.
    let h2 = async move {
        let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake");
        let (req, mut stream) = srv.next().await.unwrap().unwrap();
        assert_eq!(req.method(), &http::Method::GET);

        let rsp = http::Response::builder().status(200).body(()).unwrap();
        // Drop the send stream immediately so no half-open stream keeps the
        // connection from being considered idle.
        drop(stream.send_response(rsp, true).unwrap());

        // The next accept must resolve to a keepalive-triggered GOAWAY error.
        let next = srv.accept().await;
        assert!(next.is_some_and(|res| res.is_err_and(|err| err.is_go_away())));
    };
    join(client, h2).await;
}

#[tokio::test]
async fn serve_request() {
h2_support::trace_init!();