diff --git a/.gitignore b/.gitignore index c2a55a1..42a9804 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ **/*.rs.bk tags .ccls-cache +.idea diff --git a/Cargo.lock b/Cargo.lock index fc3fa49..3ef6c45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -810,6 +810,7 @@ version = "0.3.1" dependencies = [ "futures", "futures-core", + "futures-util", "libc", "serial_test 0.5.1", "signal-hook", diff --git a/signal-hook-async-std/src/lib.rs b/signal-hook-async-std/src/lib.rs index f7d0c12..8b7eeac 100644 --- a/signal-hook-async-std/src/lib.rs +++ b/signal-hook-async-std/src/lib.rs @@ -64,6 +64,8 @@ use signal_hook::iterator::exfiltrator::{Exfiltrator, SignalOnly}; use async_io::Async; use futures_lite::io::AsyncRead; use futures_lite::stream::Stream; +use futures_lite::StreamExt; +use signal_hook::consts; /// An asynchronous [`Stream`] of arriving signals. /// @@ -133,3 +135,20 @@ impl Stream for SignalsInfo { /// This one simply returns the signal numbers, while [`SignalsInfo`] can provide additional /// information. pub type Signals = SignalsInfo; + +/// Waits for the the process to receive a shutdown signal. +/// Waits for any of SIGHUP, SIGINT, SIGQUIT, and SIGTERM. +/// # Errors +/// Returns `Err` after failing to register the signal handler. +pub async fn wait_for_shutdown_signal() -> Result<(), String> { + let signals = [ + consts::SIGHUP, + consts::SIGINT, + consts::SIGQUIT, + consts::SIGTERM, + ]; + let mut signals = Signals::new(&signals) + .map_err(|e| format!("error setting up handler for signals {signals:?}: {e}"))?; + let _ = signals.next().await; + Ok(()) +} diff --git a/signal-hook-async-std/tests/async_std.rs b/signal-hook-async-std/tests/async_std.rs index 1996a81..bb256dd 100644 --- a/signal-hook-async-std/tests/async_std.rs +++ b/signal-hook-async-std/tests/async_std.rs @@ -1,10 +1,11 @@ use async_std::stream::StreamExt; +use std::convert::TryFrom; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; -use signal_hook::consts::SIGUSR1; +use signal_hook::consts::{SIGHUP, SIGUSR1}; use signal_hook::low_level::raise; use signal_hook_async_std::Signals; @@ -52,3 +53,25 @@ async fn delayed() { signals_task.await; assert!(recieved.load(Ordering::SeqCst)); } + +#[async_std::test] +#[serial] +async fn wait_for_shutdown_signal() { + let elapsed_ms = Arc::new(AtomicU64::new(0)); + let elapsed_ms_clone = Arc::clone(&elapsed_ms); + async_std::task::spawn(async move { + let before = Instant::now(); + signal_hook_async_std::wait_for_shutdown_signal() + .await + .unwrap(); + let elapsed_ms_u64 = + u64::try_from(Instant::now().saturating_duration_since(before).as_millis()) + .unwrap_or(u64::MAX); + elapsed_ms_clone.store(elapsed_ms_u64, Ordering::Release) + }); + async_std::task::sleep(Duration::from_millis(100)).await; + raise(SIGHUP).unwrap(); + async_std::task::sleep(Duration::from_millis(100)).await; + let elapsed_ms_u64 = elapsed_ms.load(Ordering::Acquire); + assert!((50..=150).contains(&elapsed_ms_u64), "{:?}", elapsed_ms_u64); +} diff --git a/signal-hook-tokio/Cargo.toml b/signal-hook-tokio/Cargo.toml index 5bb2284..b6a4f86 100644 --- a/signal-hook-tokio/Cargo.toml +++ b/signal-hook-tokio/Cargo.toml @@ -20,12 +20,14 @@ travis-ci = { repository = "vorner/signal-hook" } maintenance = { status = "actively-developed" } [features] +convenience = ["futures-v0_3", "futures-util"] futures-v0_3 = ["futures-core-0_3"] [dependencies] libc = "~0.2" signal-hook = { version = "~0.3", path = ".." } futures-core-0_3 = { package = "futures-core", version = "~0.3", optional = true } +futures-util = { version = "~0.3", features = [], optional = true } tokio = { version = "~1", features = ["net"] } [dev-dependencies] diff --git a/signal-hook-tokio/src/lib.rs b/signal-hook-tokio/src/lib.rs index 7d949be..ecbde2b 100644 --- a/signal-hook-tokio/src/lib.rs +++ b/signal-hook-tokio/src/lib.rs @@ -83,6 +83,9 @@ use signal_hook::iterator::exfiltrator::{Exfiltrator, SignalOnly}; #[cfg(feature = "futures-v0_3")] use futures_core_0_3::Stream; +#[cfg(feature = "convenience")] +use futures_util::StreamExt; +use signal_hook::consts; /// An asynchronous [`Stream`] of arriving signals. /// @@ -156,3 +159,21 @@ impl Stream for SignalsInfo { } } } + +#[cfg(feature = "convenience")] +/// Waits for the the process to receive a shutdown signal. +/// Waits for any of SIGHUP, SIGINT, SIGQUIT, and SIGTERM. +/// # Errors +/// Returns `Err` after failing to register the signal handler. +pub async fn wait_for_shutdown_signal() -> Result<(), String> { + let signals = [ + consts::SIGHUP, + consts::SIGINT, + consts::SIGQUIT, + consts::SIGTERM, + ]; + let mut signals = Signals::new(&signals) + .map_err(|e| format!("error setting up handler for signals {signals:?}: {e}"))?; + let _ = signals.next().await; + Ok(()) +} diff --git a/signal-hook-tokio/tests/tests.rs b/signal-hook-tokio/tests/tests.rs index acfcc43..319484b 100644 --- a/signal-hook-tokio/tests/tests.rs +++ b/signal-hook-tokio/tests/tests.rs @@ -2,11 +2,12 @@ use futures::stream::StreamExt; -use std::sync::atomic::{AtomicBool, Ordering}; +use std::convert::TryFrom; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; -use std::time::Duration; +use std::time::{Duration, Instant}; -use signal_hook::consts::SIGUSR1; +use signal_hook::consts::{SIGHUP, SIGUSR1}; use signal_hook::low_level::raise; use signal_hook_tokio::Signals; @@ -54,3 +55,23 @@ async fn delayed() { signals_task.await.unwrap(); assert!(recieved.load(Ordering::SeqCst)); } + +#[tokio::test] +#[serial] +async fn wait_for_shutdown_signal() { + let elapsed_ms = Arc::new(AtomicU64::new(0)); + let elapsed_ms_clone = Arc::clone(&elapsed_ms); + tokio::spawn(async move { + let before = Instant::now(); + signal_hook_tokio::wait_for_shutdown_signal().await.unwrap(); + let elapsed_ms_u64 = + u64::try_from(Instant::now().saturating_duration_since(before).as_millis()) + .unwrap_or(u64::MAX); + elapsed_ms_clone.store(elapsed_ms_u64, Ordering::Release) + }); + tokio::time::sleep(Duration::from_millis(100)).await; + raise(SIGHUP).unwrap(); + tokio::time::sleep(Duration::from_millis(100)).await; + let elapsed_ms_u64 = elapsed_ms.load(Ordering::Acquire); + assert!((50..=150).contains(&elapsed_ms_u64), "{:?}", elapsed_ms_u64); +} diff --git a/src/lib.rs b/src/lib.rs index d750e15..57eb71a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -412,3 +412,20 @@ pub mod consts { } pub use signal_hook_registry::SigId; + +/// Waits for the the process to receive a shutdown signal. +/// Waits for any of SIGHUP, SIGINT, SIGQUIT, and SIGTERM. +/// # Errors +/// Returns `Err` after failing to register the signal handler. +pub fn wait_for_shutdown_signal() -> Result<(), String> { + let signals = [ + consts::SIGHUP, + consts::SIGINT, + consts::SIGQUIT, + consts::SIGTERM, + ]; + let mut signals = iterator::Signals::new(&signals) + .map_err(|e| format!("error setting up handler for signals {signals:?}: {e}"))?; + let _ = signals.forever().next(); + Ok(()) +} diff --git a/tests/test.rs b/tests/test.rs new file mode 100644 index 0000000..67546a3 --- /dev/null +++ b/tests/test.rs @@ -0,0 +1,20 @@ +use std::time::{Duration, Instant}; + +#[test] +fn wait_for_shutdown_signal() { + let (sender, receiver) = std::sync::mpsc::sync_channel(1); + std::thread::spawn(move || { + let before = Instant::now(); + signal_hook::wait_for_shutdown_signal().unwrap(); + let elapsed = Instant::now().saturating_duration_since(before); + sender.send(elapsed).unwrap(); + }); + std::thread::sleep(Duration::from_millis(100)); + signal_hook::low_level::raise(signal_hook::consts::SIGHUP).unwrap(); + let elapsed = receiver.recv_timeout(Duration::from_millis(100)).unwrap(); + assert!( + (Duration::from_millis(50)..=Duration::from_millis(150)).contains(&elapsed), + "{:?}", + elapsed + ); +}