diff --git a/src/configure.rs b/src/configure.rs
index 4c1c77090b..823ca302c4 100644
--- a/src/configure.rs
+++ b/src/configure.rs
@@ -73,18 +73,16 @@ impl Context {
             self.sql.is_open().await,
             "cannot configure, database not opened."
         );
-        let cancel_channel = self.alloc_ongoing().await?;
+        let ongoing_guard = self.alloc_ongoing().await?;
         let res = self
             .inner_configure()
-            .race(cancel_channel.recv().map(|_| {
+            .race(ongoing_guard.map(|_| {
                 progress!(self, 0);
                 Ok(())
             }))
             .await;
-        self.free_ongoing().await;
-
         if let Err(err) = res.as_ref() {
             progress!(
                 self,
diff --git a/src/context.rs b/src/context.rs
index 3ff806e38e..58883aefd8 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -2,16 +2,18 @@
 use std::collections::{BTreeMap, HashMap};
 use std::ffi::OsString;
+use std::future::Future;
 use std::ops::Deref;
 use std::path::{Path, PathBuf};
+use std::pin::Pin;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
+use std::task::Poll;
 use std::time::{Duration, Instant, SystemTime};
 
 use anyhow::{bail, ensure, Context as _, Result};
-use async_channel::{self as channel, Receiver, Sender};
 use ratelimit::Ratelimit;
-use tokio::sync::{Mutex, Notify, RwLock};
+use tokio::sync::{oneshot, Mutex, Notify, RwLock};
 
 use crate::chat::{get_chat_cnt, ChatId};
 use crate::config::Config;
@@ -250,7 +252,7 @@ pub struct InnerContext {
 #[derive(Debug)]
 enum RunningState {
     /// Ongoing process is allocated.
-    Running { cancel_sender: Sender<()> },
+    Running { cancel_sender: oneshot::Sender<()> },
 
     /// Cancel signal has been sent, waiting for ongoing process to be freed.
     ShallStop { request: Instant },
@@ -502,51 +504,66 @@ impl Context {
     /// This is for modal operations during which no other user actions are allowed. Only
     /// one such operation is allowed at any given time.
     ///
-    /// The return value is a cancel token, which will release the ongoing mutex when
-    /// dropped.
-    pub(crate) async fn alloc_ongoing(&self) -> Result<Receiver<()>> {
+    /// The return value is a guard which does two things:
+    ///
+    /// - It is a [`Future`] which completes when the ongoing process is cancelled using
+    ///   [`Context::stop_ongoing`] and must stop.
+    /// - It frees the ongoing process, i.e. releases the mutex, when dropped.
+    pub(crate) async fn alloc_ongoing(&self) -> Result<OngoingGuard> {
         let mut s = self.running_state.write().await;
         ensure!(
             matches!(*s, RunningState::Stopped),
             "There is already another ongoing process running."
         );
 
-        let (sender, receiver) = channel::bounded(1);
+        let (cancel_tx, cancel_rx) = oneshot::channel();
         *s = RunningState::Running {
-            cancel_sender: sender,
+            cancel_sender: cancel_tx,
         };
 
+        let (drop_tx, drop_rx) = oneshot::channel();
+        let context = self.clone();
+
+        tokio::spawn(async move {
+            drop_rx.await.ok();
+            let mut s = context.running_state.write().await;
+            if let RunningState::ShallStop { request } = *s {
+                info!(context, "Ongoing stopped in {:?}", request.elapsed());
+            }
+            *s = RunningState::Stopped;
+        });
 
-        Ok(receiver)
-    }
-
-    pub(crate) async fn free_ongoing(&self) {
-        let mut s = self.running_state.write().await;
-        if let RunningState::ShallStop { request } = *s {
-            info!(self, "Ongoing stopped in {:?}", request.elapsed());
-        }
-        *s = RunningState::Stopped;
+        Ok(OngoingGuard {
+            cancel_rx,
+            drop_tx: Some(drop_tx),
+        })
     }
 
     /// Signal an ongoing process to stop.
     pub async fn stop_ongoing(&self) {
         let mut s = self.running_state.write().await;
-        match &*s {
-            RunningState::Running { cancel_sender } => {
-                if let Err(err) = cancel_sender.send(()).await {
-                    warn!(self, "could not cancel ongoing: {:#}", err);
-                }
-                info!(self, "Signaling the ongoing process to stop ASAP.",);
-                *s = RunningState::ShallStop {
-                    request: Instant::now(),
-                };
-            }
+
+        // Take the state out so we can call the oneshot sender, which consumes itself.
+        let current_state = std::mem::replace(
+            &mut *s,
+            RunningState::ShallStop {
+                request: Instant::now(),
+            },
+        );
+
+        match current_state {
+            RunningState::Running { cancel_sender } => match cancel_sender.send(()) {
+                Ok(()) => info!(self, "Signaling the ongoing process to stop ASAP."),
+                Err(()) => warn!(self, "could not cancel ongoing"),
+            },
             RunningState::ShallStop { .. } | RunningState::Stopped => {
+                // Put back the current state.
+                *s = current_state;
                 info!(self, "No ongoing process to stop.",);
             }
         }
     }
 
-    #[allow(unused)]
+    #[cfg(test)]
     pub(crate) async fn shall_stop_ongoing(&self) -> bool {
         match &*self.running_state.read().await {
             RunningState::Running { .. } => false,
@@ -1034,6 +1051,54 @@ impl Context {
     }
 }
 
+/// Guard received when calling [`Context::alloc_ongoing`].
+///
+/// While this guard is held the ongoing mutex is held; dropping the guard frees the
+/// ongoing process.
+///
+/// The ongoing process can also be cancelled by unrelated code calling
+/// [`Context::stop_ongoing`].  This guard implements [`Future`] and the future
+/// completes when the ongoing process is cancelled and must be aborted.  Freeing the
+/// ongoing process works as usual in this case: when this guard is dropped.  So if you
+/// need to do more work before freeing, make sure to keep ownership of the guard, e.g.:
+///
+/// ```no_compile
+/// let mut guard = context.alloc_ongoing().await?;
+/// tokio::select! {
+///     biased;
+///     _ = &mut guard => (), // guard is not moved, so we keep ownership.
+///     _ = do_work() => (),
+/// };
+/// do_cleanup().await;
+/// drop(guard);
+/// ```
+pub(crate) struct OngoingGuard {
+    /// Receives a message when the ongoing process should be cancelled.
+    cancel_rx: oneshot::Receiver<()>,
+    /// Used by `Drop` to send a message which will free the ongoing process.
+    drop_tx: Option<oneshot::Sender<()>>,
+}
+
+impl Future for OngoingGuard {
+    type Output = ();
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
+        match Pin::new(&mut self.cancel_rx).poll(cx) {
+            Poll::Ready(_) => Poll::Ready(()),
+            Poll::Pending => Poll::Pending,
+        }
+    }
+}
+
+impl Drop for OngoingGuard {
+    fn drop(&mut self) {
+        if let Some(sender) = self.drop_tx.take() {
+            // TODO: Maybe this should log?  But we'd need to have a context.
+            sender.send(()).ok();
+        }
+    }
+}
+
 /// Returns core version as a string.
 pub fn get_version_str() -> &'static str {
     &DC_VERSION_STR
@@ -1524,38 +1589,52 @@ mod tests {
     async fn test_ongoing() -> Result<()> {
         let context = TestContext::new().await;
 
-        // No ongoing process allocated.
+        println!("No ongoing process allocated.");
         assert!(context.shall_stop_ongoing().await);
 
-        let receiver = context.alloc_ongoing().await?;
+        let mut guard = context.alloc_ongoing().await?;
 
-        // Cannot allocate another ongoing process while the first one is running.
+        println!("Cannot allocate another ongoing process while the first one is running.");
         assert!(context.alloc_ongoing().await.is_err());
 
-        // Stop signal is not sent yet.
-        assert!(receiver.try_recv().is_err());
+        println!("Stop signal is not sent yet.");
+        assert!(matches!(futures::poll!(&mut guard), Poll::Pending));
         assert!(!context.shall_stop_ongoing().await);
 
-        // Send the stop signal.
+        println!("Send the stop signal.");
         context.stop_ongoing().await;
 
-        // Receive stop signal.
-        receiver.recv().await?;
+        println!("Receive stop signal.");
+        (&mut guard).await;
         assert!(context.shall_stop_ongoing().await);
 
-        // Ongoing process is still running even though stop signal was received,
-        // so another one cannot be allocated.
+        println!("Ongoing process still running even though stop signal was received");
         assert!(context.alloc_ongoing().await.is_err());
 
-        context.free_ongoing().await;
-
-        // No ongoing process allocated, should have been stopped already.
-        assert!(context.shall_stop_ongoing().await);
-
-        // Another ongoing process can be allocated now.
-        let _receiver = context.alloc_ongoing().await?;
+        println!("free the ongoing process");
+        // context.free_ongoing().await;
+        drop(guard);
+
+        println!("re-acquire the ongoing process");
+        // Since the drop guard needs to send a message, and the receiving task must run
+        // and acquire a lock, this takes some time and won't succeed immediately.
+        #[allow(clippy::async_yields_async)]
+        let _guard = tokio::time::timeout(Duration::from_secs(10), async {
+            loop {
+                match context.alloc_ongoing().await {
+                    Ok(guard) => break guard,
+                    Err(_) => {
+                        // tokio::task::yield_now() results in a much hotter loop;
+                        // it takes a lot of yields.
+                        tokio::time::sleep(Duration::from_millis(1)).await;
+                    }
+                }
+            }
+        })
+        .await
+        .expect("timeout");
 
         Ok(())
     }
diff --git a/src/imex.rs b/src/imex.rs
index 236f5e6fcc..f919a2d90d 100644
--- a/src/imex.rs
+++ b/src/imex.rs
@@ -89,18 +89,17 @@ pub async fn imex(
     path: &Path,
     passphrase: Option<String>,
 ) -> Result<()> {
-    let cancel = context.alloc_ongoing().await?;
+    let ongoing_guard = context.alloc_ongoing().await?;
 
     let res = {
         let _guard = context.scheduler.pause(context.clone()).await?;
         imex_inner(context, what, path, passphrase)
             .race(async {
-                cancel.recv().await.ok();
-                Err(format_err!("canceled"))
+                ongoing_guard.await;
+                Err(format_err!("cancelled"))
             })
             .await
     };
-    context.free_ongoing().await;
 
     if let Err(err) = res.as_ref() {
         // We are using Anyhow's .context() and to show the inner error, too, we need the {:#}:
diff --git a/src/imex/transfer.rs b/src/imex/transfer.rs
index 545bd94bd5..74df42655b 100644
--- a/src/imex/transfer.rs
+++ b/src/imex/transfer.rs
@@ -30,7 +30,6 @@ use std::pin::Pin;
 use std::task::Poll;
 
 use anyhow::{anyhow, bail, ensure, format_err, Context as _, Result};
-use async_channel::Receiver;
 use futures_lite::StreamExt;
 use iroh::blobs::Collection;
 use iroh::get::DataStream;
@@ -48,7 +47,7 @@ use tokio_util::sync::CancellationToken;
 
 use crate::blob::BlobDirContents;
 use crate::chat::{add_device_msg, delete_and_reset_all_device_msgs};
-use crate::context::Context;
+use crate::context::{Context, OngoingGuard};
 use crate::message::{Message, Viewtype};
 use crate::qr::{self, Qr};
 use crate::stock_str::backup_transfer_msg_body;
@@ -98,8 +97,8 @@ impl BackupProvider {
             .context("Private key not available, aborting backup export")?;
 
         // Acquire global "ongoing" mutex.
-        let cancel_token = context.alloc_ongoing().await?;
-        let paused_guard = context.scheduler.pause(context.clone()).await?;
+        let mut ongoing_guard = context.alloc_ongoing().await?;
+        let paused_guard = context.scheduler.pause(context.clone()).await;
         let context_dir = context
             .get_blobdir()
             .parent()
@@ -110,7 +109,7 @@
             warn!(context, "Previous database export deleted");
         }
         let dbfile = TempPathGuard::new(dbfile);
-        let res = tokio::select! {
+        let (provider, ticket) = tokio::select! {
             biased;
             res = Self::prepare_inner(context, &dbfile) => {
                 match res {
@@ -121,22 +120,14 @@
                     },
                 }
             },
-            _ = cancel_token.recv() => Err(format_err!("cancelled")),
-        };
-        let (provider, ticket) = match res {
-            Ok((provider, ticket)) => (provider, ticket),
-            Err(err) => {
-                context.free_ongoing().await;
-                return Err(err);
-            }
-        };
+            _ = &mut ongoing_guard => Err(format_err!("cancelled")),
+        }?;
         let drop_token = CancellationToken::new();
         let handle = {
             let context = context.clone();
             let drop_token = drop_token.clone();
             tokio::spawn(async move {
-                let res = Self::watch_provider(&context, provider, cancel_token, drop_token).await;
-                context.free_ongoing().await;
+                let res = Self::watch_provider(&context, provider, ongoing_guard, drop_token).await;
 
                 // Explicit drop to move the guards into this future
                 drop(paused_guard);
@@ -201,7 +192,7 @@ impl BackupProvider {
     async fn watch_provider(
         context: &Context,
         mut provider: Provider,
-        cancel_token: Receiver<()>,
+        mut cancel_token: OngoingGuard,
         drop_token: CancellationToken,
     ) -> Result<()> {
         let mut events = provider.subscribe();
@@ -261,7 +252,7 @@ impl BackupProvider {
                     }
                 }
             },
-            _ = cancel_token.recv() => {
+            _ = &mut cancel_token => {
                 provider.shutdown();
                 break Err(anyhow!("BackupProvider cancelled"));
             },
@@ -394,20 +385,18 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> {
         "Cannot import backups to accounts in use."
     );
     // Acquire global "ongoing" mutex.
-    let cancel_token = context.alloc_ongoing().await?;
+    let mut cancel_token = context.alloc_ongoing().await?;
     let _guard = context.scheduler.pause(context.clone()).await;
 
     info!(
         context,
         "Running get_backup for {}",
         qr::format_backup(&qr)?
     );
-    let res = tokio::select! {
+    tokio::select! {
         biased;
         res = get_backup_inner(context, qr) => res,
-        _ = cancel_token.recv() => Err(format_err!("cancelled")),
-    };
-    context.free_ongoing().await;
-    res
+        _ = &mut cancel_token => Err(format_err!("cancelled")),
+    }
 }
 
 async fn get_backup_inner(context: &Context, qr: Qr) -> Result<()> {
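
Reviewer note: below is a minimal, self-contained sketch of the cancel-on-future / free-on-drop pattern this diff introduces, built from plain tokio primitives. `Guard`, the channel wiring in `main`, and the printed messages are illustrative stand-ins, not code from this diff; it assumes tokio with the `full` feature set.

```rust
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::sync::oneshot;

/// Illustrative stand-in for `OngoingGuard`: a future that completes on
/// cancellation and notifies a cleanup task when dropped.
struct Guard {
    cancel_rx: oneshot::Receiver<()>,
    drop_tx: Option<oneshot::Sender<()>>,
}

impl Future for Guard {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // Either a cancel message or a closed channel means "stop".
        match Pin::new(&mut self.cancel_rx).poll(cx) {
            Poll::Ready(_) => Poll::Ready(()),
            Poll::Pending => Poll::Pending,
        }
    }
}

impl Drop for Guard {
    fn drop(&mut self) {
        // `Drop` cannot be async, so only fire a message here; the async
        // state reset happens in the task listening on the other end.
        if let Some(tx) = self.drop_tx.take() {
            tx.send(()).ok();
        }
    }
}

#[tokio::main]
async fn main() {
    let (cancel_tx, cancel_rx) = oneshot::channel();
    let (drop_tx, drop_rx) = oneshot::channel();

    // Stands in for the task spawned in `alloc_ongoing` that resets
    // `RunningState` to `Stopped` once the guard is dropped.
    let cleanup = tokio::spawn(async move {
        drop_rx.await.ok();
        println!("ongoing process freed");
    });

    let mut guard = Guard {
        cancel_rx,
        drop_tx: Some(drop_tx),
    };

    // Stands in for `Context::stop_ongoing`.
    cancel_tx.send(()).ok();

    // Borrowing with `&mut guard` keeps ownership, as the doc comment advises.
    tokio::select! {
        biased;
        _ = &mut guard => println!("cancelled, aborting work"),
        _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => println!("work done"),
    }

    drop(guard); // frees the "ongoing" slot
    cleanup.await.unwrap();
}
```

The two-channel split is the key design point: `Drop` cannot await, so the destructor only sends a message, and the task spawned in `alloc_ongoing` performs the async `running_state` reset that `free_ongoing` used to do. This is also why the test above polls `alloc_ongoing` in a loop instead of expecting it to succeed immediately after `drop(guard)`.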