diff --git a/Cargo.toml b/Cargo.toml index 1e5fcea05..d936c3903 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,6 +73,10 @@ harness = false name = "accumulator" harness = false +[[bench]] +name = "dataflow" +harness = false + [workspace] members = ["components/salsa-macro-rules", "components/salsa-macros"] diff --git a/benches/dataflow.rs b/benches/dataflow.rs new file mode 100644 index 000000000..d535046e9 --- /dev/null +++ b/benches/dataflow.rs @@ -0,0 +1,170 @@ +//! Benchmark for fixpoint iteration cycle resolution. +//! +//! This benchmark simulates a (very simplified) version of a real dataflow analysis using fixpoint +//! iteration. +use codspeed_criterion_compat::{criterion_group, criterion_main, BatchSize, Criterion}; +use salsa::{CycleRecoveryAction, Database as Db, Setter}; +use std::collections::BTreeSet; +use std::iter::IntoIterator; + +/// A Use of a symbol. +#[salsa::input] +struct Use { + reaching_definitions: Vec, +} + +/// A Definition of a symbol, either of the form `base + increment` or `0 + increment`. +#[salsa::input] +struct Definition { + base: Option, + increment: usize, +} + +#[derive(Eq, PartialEq, Clone, Debug, salsa::Update)] +enum Type { + Bottom, + Values(Box<[usize]>), + Top, +} + +impl Type { + fn join(tys: impl IntoIterator) -> Type { + let mut result = Type::Bottom; + for ty in tys.into_iter() { + result = match (result, ty) { + (result, Type::Bottom) => result, + (_, Type::Top) => Type::Top, + (Type::Top, _) => Type::Top, + (Type::Bottom, ty) => ty, + (Type::Values(a_ints), Type::Values(b_ints)) => { + let mut set = BTreeSet::new(); + set.extend(a_ints); + set.extend(b_ints); + Type::Values(set.into_iter().collect()) + } + } + } + result + } +} + +#[salsa::tracked(cycle_fn=use_cycle_recover, cycle_initial=use_cycle_initial)] +fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { + let defs = u.reaching_definitions(db); + match defs[..] 
{ + [] => Type::Bottom, + [def] => infer_definition(db, def), + _ => Type::join(defs.iter().map(|&def| infer_definition(db, def))), + } +} + +#[salsa::tracked(cycle_fn=def_cycle_recover, cycle_initial=def_cycle_initial)] +fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { + let increment_ty = Type::Values(Box::from([def.increment(db)])); + if let Some(base) = def.base(db) { + let base_ty = infer_use(db, base); + add(&base_ty, &increment_ty) + } else { + increment_ty + } +} + +fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { + Type::Bottom +} + +fn def_cycle_recover( + _db: &dyn Db, + value: &Type, + count: u32, + _def: Definition, +) -> CycleRecoveryAction { + cycle_recover(value, count) +} + +fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { + Type::Bottom +} + +fn use_cycle_recover( + _db: &dyn Db, + value: &Type, + count: u32, + _use: Use, +) -> CycleRecoveryAction { + cycle_recover(value, count) +} + +fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction { + match value { + Type::Bottom => CycleRecoveryAction::Iterate, + Type::Values(_) => { + if count > 4 { + CycleRecoveryAction::Fallback(Type::Top) + } else { + CycleRecoveryAction::Iterate + } + } + Type::Top => CycleRecoveryAction::Iterate, + } +} + +fn add(a: &Type, b: &Type) -> Type { + match (a, b) { + (Type::Bottom, _) | (_, Type::Bottom) => Type::Bottom, + (Type::Top, _) | (_, Type::Top) => Type::Top, + (Type::Values(a_ints), Type::Values(b_ints)) => { + let mut set = BTreeSet::new(); + set.extend( + a_ints + .into_iter() + .flat_map(|a| b_ints.into_iter().map(move |b| a + b)), + ); + Type::Values(set.into_iter().collect()) + } + } +} + +fn dataflow(criterion: &mut Criterion) { + criterion.bench_function("converge_diverge", |b| { + b.iter_batched_ref( + || { + let mut db = salsa::DatabaseImpl::new(); + + let defx0 = Definition::new(&db, None, 0); + let defy0 = Definition::new(&db, None, 0); + let defx1 = Definition::new(&db, None, 0); + let defy1 = 
Definition::new(&db, None, 0); + let use_x = Use::new(&db, vec![defx0, defx1]); + let use_y = Use::new(&db, vec![defy0, defy1]); + defx1.set_base(&mut db).to(Some(use_y)); + defy1.set_base(&mut db).to(Some(use_x)); + + // prewarm cache + let _ = infer_use(&db, use_x); + let _ = infer_use(&db, use_y); + + (db, defx1, use_x, use_y) + }, + |(db, defx1, use_x, use_y)| { + // Set the increment on x to 0. + defx1.set_increment(db).to(0); + + // Both symbols converge on 0. + assert_eq!(infer_use(db, *use_x), Type::Values(Box::from([0]))); + assert_eq!(infer_use(db, *use_y), Type::Values(Box::from([0]))); + + // Set the increment on x to 1. + defx1.set_increment(db).to(1); + + // Now the loop diverges and we fall back to Top. + assert_eq!(infer_use(db, *use_x), Type::Top); + assert_eq!(infer_use(db, *use_y), Type::Top); + }, + BatchSize::LargeInput, + ); + }); +} + +criterion_group!(benches, dataflow); +criterion_main!(benches); diff --git a/book/src/SUMMARY.md b/book/src/SUMMARY.md index 675dd3e0b..4602916ed 100644 --- a/book/src/SUMMARY.md +++ b/book/src/SUMMARY.md @@ -22,7 +22,6 @@ - [On-demand (Lazy) inputs](./common_patterns/on_demand_inputs.md) - [Tuning](./tuning.md) - [Cycle handling](./cycles.md) - - [Recovering via fallback](./cycles/fallback.md) # How Salsa works internally diff --git a/book/src/cycles.md b/book/src/cycles.md index 507dbde02..8222d9eaf 100644 --- a/book/src/cycles.md +++ b/book/src/cycles.md @@ -1,5 +1,40 @@ # Cycle handling -By default, when Salsa detects a cycle in the computation graph, Salsa will panic with a [`salsa::Cycle`] as the panic value. The [`salsa::Cycle`] structure that describes the cycle, which can be useful for diagnosing what went wrong. +By default, when Salsa detects a cycle in the computation graph, Salsa will panic with a message naming the "cycle head"; this is the query that was called while it was also on the active query stack, creating a cycle. 
-[`salsa::cycle`]: https://github.com/salsa-rs/salsa/blob/0f9971ad94d5d137f1192fde2b02ccf1d2aca28c/src/lib.rs#L654-L672 +Salsa also supports recovering from query cycles via fixed-point iteration. Fixed-point iteration is only usable if the queries which may be involved in a cycle are monotone and operate on a value domain which is a partial order with fixed height. Effectively, this means that the queries' output must always be "larger" than its input, and there must be some "maximum" or "top" value. This ensures that fixed-point iteration will converge to a value. (A typical case would be queries operating on types, which form a partial order with a "top" type.) + +In order to support fixed-point iteration for a query, provide the `cycle_fn` and `cycle_initial` arguments to `salsa::tracked`: + +```rust +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial_fn)] +fn query(db: &dyn salsa::Database) -> u32 { + // ... +} + +fn cycle_fn(_db: &dyn KnobsDatabase, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Iterate +} + +fn initial(_db: &dyn KnobsDatabase) -> u32 { + 0 +} +``` + +If `query` becomes the head of a cycle (that is, `query` is executing and on the active query stack, it calls `query2`, `query2` calls `query3`, and `query3` calls `query` again -- there could be any number of queries involved in the cycle), the `initial_fn` will be called to generate an "initial" value for `query` in the fixed-point computation. (The initial value should usually be the "bottom" value in the partial order.) All queries in the cycle will compute a provisional result based on this initial value for the cycle head. That is, `query3` will compute a provisional result using the initial value for `query`, `query2` will compute a provisional result using this provisional value for `query3`. 
When `query2` returns its provisional result back to `query`, `query` will observe that it has received a provisional result from its own cycle, and will call the `cycle_fn` (with the current value and the number of iterations that have occurred so far). The `cycle_fn` can return `salsa::CycleRecoveryAction::Iterate` to indicate that the cycle should iterate again, or `salsa::CycleRecoveryAction::Fallback(value)` to indicate that the cycle should stop iterating and fall back to the value provided. + +If the `cycle_fn` continues to return `Iterate`, the cycle will iterate until it converges: that is, until two successive iterations produce the same result. + +If the `cycle_fn` returns `Fallback`, the cycle will iterate one last time and verify that the returned value is the same as the fallback value; that is, the fallback value results in a stable converged cycle. If not, Salsa will panic. It is not permitted to use a fallback value that does not converge, because this would leave the cycle in an unpredictable state, depending on the order of query execution. + +## All potential cycle heads must set `cycle_fn` and `cycle_initial` + +Consider a two-query cycle where `query_a` calls `query_b`, and `query_b` calls `query_a`. If `query_a` is called first, then it will become the "cycle head", but if `query_b` is called first, then `query_b` will be the cycle head. In order for a cycle to use fixed-point iteration instead of panicking, the cycle head must set `cycle_fn` and `cycle_initial`. This means that in order to be robust against varying query execution order, both `query_a` and `query_b` must set `cycle_fn` and `cycle_initial`. + +## Ensuring convergence + +Fixed-point iteration is a powerful tool, but is also easy to misuse, potentially resulting in infinite iteration. To avoid this, ensure that all queries participating in fixpoint iteration are deterministic and monotone. 
+ +## Calling Salsa queries from within `cycle_fn` or `cycle_initial` + +It is permitted to call other Salsa queries from within the `cycle_fn` and `cycle_initial` functions. However, if these functions re-enter the same cycle, this can lead to unpredictable results. Take care which queries are called from within cycle-recovery functions, and avoid triggering further cycles. diff --git a/book/src/cycles/fallback.md b/book/src/cycles/fallback.md deleted file mode 100644 index d3971ece5..000000000 --- a/book/src/cycles/fallback.md +++ /dev/null @@ -1,21 +0,0 @@ -# Recovering via fallback - -Panicking when a cycle occurs is ok for situations where you believe a cycle is impossible. But sometimes cycles can result from illegal user input and cannot be statically prevented. In these cases, you might prefer to gracefully recover from a cycle rather than panicking the entire query. Salsa supports that with the idea of *cycle recovery*. - -To use cycle recovery, you annotate potential participants in the cycle with the `recovery_fn` argument to `#[salsa::tracked]`, e.g. `#[salsa::tracked(recovery_fn=my_recovery_fn)]`. When a cycle occurs, if any participant P has recovery information, then no panic occurs. Instead, the execution of P is aborted and P will execute the recovery function to generate its result. Participants in the cycle that do not have recovery information continue executing as normal, using this recovery result. - -The recovery function has a similar signature to a query function. It is given a reference to your database along with a `salsa::Cycle` describing the cycle that occurred and the arguments to the tracked function that caused the cycle; it returns the result of the query. Example: - -```rust -fn my_recover_fn( - db: &dyn MyDatabase, - cycle: &salsa::Cycle, - arg1: T1, - ... 
- argN: TN, -) -> MyResultValue -``` - -See [the tests](https://github.com/salsa-rs/salsa/blob/cd339fc1c9a6ea0ffb1d09bd3bffb5633f776ef3/tests/cycles.rs#L132-L141) for an example. - -**Important:** Although the recovery function is given a `db` handle, you should be careful to avoid creating a cycle from within recovery or invoking queries that may be participating in the current cycle. Attempting to do so can result in inconsistent results. diff --git a/book/src/plumbing/cycles.md b/book/src/plumbing/cycles.md index ca379c5dc..39a8f8201 100644 --- a/book/src/plumbing/cycles.md +++ b/book/src/plumbing/cycles.md @@ -15,51 +15,3 @@ When a thread `T1` attempts to execute a query `Q`, it will try to load the valu * Otherwise, if `Q` is being computed by some other thread `T2`, we have to check whether `T2` is (transitively) blocked on `T1`. If so, there is a cycle. These two cases are handled internally by the `Runtime::try_block_on` function. Detecting the intra-thread cycle case is easy; to detect cross-thread cycles, the runtime maintains a dependency DAG between threads (identified by `RuntimeId`). Before adding an edge `T1 -> T2` (i.e., `T1` is blocked waiting for `T2`) into the DAG, it checks whether a path exists from `T2` to `T1`. If so, we have a cycle and the edge cannot be added (then the DAG would not longer be acyclic). - -When a cycle is detected, the current thread `T1` has full access to the query stacks that are participating in the cycle. Consider: naturally, `T1` has access to its own stack. There is also a path `T2 -> ... -> Tn -> T1` of blocked threads. Each of the blocked threads `T2 ..= Tn` will have moved their query stacks into the dependency graph, so those query stacks are available for inspection. - -Using the available stacks, we can create a list of cycle participants `Q0 ... Qn` and store that into a `Cycle` struct. If none of the participants `Q0 ... 
Qn` have cycle recovery enabled, we panic with the `Cycle` struct, which will trigger all the queries on this thread to panic. - -## Cycle recovery via fallback - -If any of the cycle participants `Q0 ... Qn` has cycle recovery set, we recover from the cycle. To help explain how this works, we will use this example cycle which contains three threads. Beginning with the current query, the cycle participants are `QA3`, `QB2`, `QB3`, `QC2`, `QC3`, and `QA2`. - -``` - The cyclic - edge we have - failed to add. - : - A : B C - : - QA1 v QB1 QC1 -┌► QA2 ┌──► QB2 ┌─► QC2 -│ QA3 ───┘ QB3 ──┘ QC3 ───┐ -│ │ -└───────────────────────────────┘ -``` - -Recovery works in phases: - -* **Analyze:** As we enumerate the query participants, we collect their collective inputs (all queries invoked so far by any cycle participant) and the max changed-at and min duration. We then remove the cycle participants themselves from this list of inputs, leaving only the queries external to the cycle. -* **Mark**: For each query Q that is annotated with `#[salsa::cycle]`, we mark it and all of its successors on the same thread by setting its `cycle` flag to the `c: Cycle` we constructed earlier; we also reset its inputs to the collective inputs gathering during analysis. If those queries resume execution later, those marks will trigger them to immediately unwind and use cycle recovery, and the inputs will be used as the inputs to the recovery value. - * Note that we mark *all* the successors of Q on the same thread, whether or not they have recovery set. We'll discuss later how this is important in the case where the active thread (A, here) doesn't have any recovery set. -* **Unblock**: Each blocked thread T that has a recovering query is forcibly reawoken; the outgoing edge from that thread to its successor in the cycle is removed. Its condvar is signalled with a `WaitResult::Cycle(c)`. When the thread reawakens, it will see that and start unwinding with the cycle `c`. 
-* **Handle the current thread:** Finally, we have to choose how to have the current thread proceed. If the current thread includes any cycle with recovery information, then we can begin unwinding. Otherwise, the current thread simply continues as if there had been no cycle, and so the cyclic edge is added to the graph and the current thread blocks. This is possible because some other thread had recovery information and therefore has been awoken. - -Let's walk through the process with a few examples. - -### Example 1: Recovery on the detecting thread - -Consider the case where only the query QA2 has recovery set. It and QA3 will be marked with their `cycle` flag set to `c: Cycle`. Threads B and C will not be unblocked, as they do not have any cycle recovery nodes. The current thread (Thread A) will initiate unwinding with the cycle `c` as the value. Unwinding will pass through QA3 and be caught by QA2. QA2 will substitute the recovery value and return normally. QA1 and QC3 will then complete normally and so forth, on up until all queries have completed. - -### Example 2: Recovery in two queries on the detecting thread - -Consider the case where both query QA2 and QA3 have recovery set. It proceeds the same Example 1 until the current thread initiates unwinding, as described in Example 1. When QA3 receives the cycle, it stores its recovery value and completes normally. QA2 then adds QA3 as an input dependency: at that point, QA2 observes that it too has the cycle mark set, and so it initiates unwinding. The rest of QA2 therefore never executes. This unwinding is caught by QA2's entry point and it stores the recovery value and returns normally. QA1 and QC3 then continue normally, as they have not had their `cycle` flag set. - -### Example 3: Recovery on another thread - -Now consider the case where only the query QB2 has recovery set. 
It and QB3 will be marked with the cycle `c: Cycle` and thread B will be unblocked; the edge `QB3 -> QC2` will be removed from the dependency graph. Thread A will then add an edge `QA3 -> QB2` and block on thread B. At that point, thread A releases the lock on the dependency graph, and so thread B is re-awoken. It observes the `WaitResult::Cycle` and initiates unwinding. Unwinding proceeds through QB3 and into QB2, which recovers. QB1 is then able to execute normally, as is QA3, and execution proceeds from there. - -### Example 4: Recovery on all queries - -Now consider the case where all the queries have recovery set. In that case, they are all marked with the cycle, and all the cross-thread edges are removed from the graph. Each thread will independently awaken and initiate unwinding. Each query will recover. diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index fcc641863..efc27398f 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -37,6 +37,9 @@ macro_rules! setup_tracked_fn { // Path to the cycle recovery function to use. cycle_recovery_fn: ($($cycle_recovery_fn:tt)*), + // Path to function to get the initial value to use for cycle recovery. + cycle_recovery_initial: ($($cycle_recovery_initial:tt)*), + // Name of cycle recovery strategy variant to use. cycle_recovery_strategy: $cycle_recovery_strategy:ident, @@ -180,7 +183,7 @@ macro_rules! setup_tracked_fn { const CYCLE_STRATEGY: $zalsa::CycleRecoveryStrategy = $zalsa::CycleRecoveryStrategy::$cycle_recovery_strategy; - fn should_backdate_value( + fn values_equal( old_value: &Self::Output<'_>, new_value: &Self::Output<'_>, ) -> bool { @@ -188,7 +191,7 @@ macro_rules! setup_tracked_fn { if $no_eq { false } else { - $zalsa::should_backdate_value(old_value, new_value) + $zalsa::values_equal(old_value, new_value) } } } @@ -201,12 +204,17 @@ macro_rules! 
setup_tracked_fn { $inner($db, $($input_id),*) } + fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db, ($($input_id),*): ($($input_ty),*)) -> Self::Output<$db_lt> { + $($cycle_recovery_initial)*(db, $($input_id),*) + } + fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, - cycle: &$zalsa::Cycle, + value: &Self::Output<$db_lt>, + count: u32, ($($input_id),*): ($($input_ty),*) - ) -> Self::Output<$db_lt> { - $($cycle_recovery_fn)*(db, cycle, $($input_id),*) + ) -> $zalsa::CycleRecoveryAction> { + $($cycle_recovery_fn)*(db, value, count, $($input_id),*) } fn id_to_input<$db_lt>(db: &$db_lt Self::DbView, key: salsa::Id) -> Self::Input<$db_lt> { diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index a8b8122b3..a1cd1e73f 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -3,11 +3,18 @@ // a macro because it can take a variadic number of arguments. #[macro_export] macro_rules! unexpected_cycle_recovery { - ($db:ident, $cycle:ident, $($other_inputs:ident),*) => { - { - std::mem::drop($db); - std::mem::drop(($($other_inputs),*)); - panic!("cannot recover from cycle `{:?}`", $cycle) - } - } + ($db:ident, $value:ident, $count:ident, $($other_inputs:ident),*) => {{ + std::mem::drop($db); + std::mem::drop(($($other_inputs),*)); + panic!("cannot recover from cycle") + }}; +} + +#[macro_export] +macro_rules! 
unexpected_cycle_initial { + ($db:ident, $($other_inputs:ident),*) => {{ + std::mem::drop($db); + std::mem::drop(($($other_inputs),*)); + panic!("no cycle initial value") + }}; } diff --git a/components/salsa-macros/src/accumulator.rs b/components/salsa-macros/src/accumulator.rs index e84bae121..2885e131b 100644 --- a/components/salsa-macros/src/accumulator.rs +++ b/components/salsa-macros/src/accumulator.rs @@ -40,7 +40,8 @@ impl AllowedOptions for Accumulator { const SINGLETON: bool = false; const DATA: bool = false; const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + const CYCLE_INITIAL: bool = false; const LRU: bool = false; const CONSTRUCTOR_NAME: bool = false; const ID: bool = false; diff --git a/components/salsa-macros/src/input.rs b/components/salsa-macros/src/input.rs index e3e560520..cc330a584 100644 --- a/components/salsa-macros/src/input.rs +++ b/components/salsa-macros/src/input.rs @@ -52,7 +52,9 @@ impl crate::options::AllowedOptions for InputStruct { const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + + const CYCLE_INITIAL: bool = false; const LRU: bool = false; diff --git a/components/salsa-macros/src/interned.rs b/components/salsa-macros/src/interned.rs index 30d89f8fb..dea7116ce 100644 --- a/components/salsa-macros/src/interned.rs +++ b/components/salsa-macros/src/interned.rs @@ -53,7 +53,9 @@ impl crate::options::AllowedOptions for InternedStruct { const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + + const CYCLE_INITIAL: bool = false; const LRU: bool = false; diff --git a/components/salsa-macros/src/options.rs b/components/salsa-macros/src/options.rs index 49c7e0eea..cbb4ecaf6 100644 --- a/components/salsa-macros/src/options.rs +++ b/components/salsa-macros/src/options.rs @@ -50,10 +50,15 @@ pub(crate) struct Options { /// If this is `Some`, the value is the ``. 
pub db_path: Option, - /// The `recovery_fn = ` option is used to indicate the recovery function. + /// The `cycle_fn = ` option is used to indicate the cycle recovery function. /// /// If this is `Some`, the value is the ``. - pub recovery_fn: Option, + pub cycle_fn: Option, + + /// The `cycle_initial = ` option is the initial value for cycle iteration. + /// + /// If this is `Some`, the value is the ``. + pub cycle_initial: Option, /// The `data = ` option is used to define the name of the data type for an interned /// struct. @@ -92,7 +97,8 @@ impl Default for Options { no_lifetime: Default::default(), no_clone: Default::default(), db_path: Default::default(), - recovery_fn: Default::default(), + cycle_fn: Default::default(), + cycle_initial: Default::default(), data: Default::default(), constructor_name: Default::default(), phantom: Default::default(), @@ -114,7 +120,8 @@ pub(crate) trait AllowedOptions { const SINGLETON: bool; const DATA: bool; const DB: bool; - const RECOVERY_FN: bool; + const CYCLE_FN: bool; + const CYCLE_INITIAL: bool; const LRU: bool; const CONSTRUCTOR_NAME: bool; const ID: bool; @@ -237,20 +244,36 @@ impl syn::parse::Parse for Options { "`db` option not allowed here", )); } - } else if ident == "recovery_fn" { - if A::RECOVERY_FN { + } else if ident == "cycle_fn" { + if A::CYCLE_FN { + let _eq = Equals::parse(input)?; + let path = syn::Path::parse(input)?; + if let Some(old) = options.cycle_fn.replace(path) { + return Err(syn::Error::new( + old.span(), + "option `cycle_fn` provided twice", + )); + } + } else { + return Err(syn::Error::new( + ident.span(), + "`cycle_fn` option not allowed here", + )); + } + } else if ident == "cycle_initial" { + if A::CYCLE_INITIAL { let _eq = Equals::parse(input)?; let path = syn::Path::parse(input)?; - if let Some(old) = options.recovery_fn.replace(path) { + if let Some(old) = options.cycle_initial.replace(path) { return Err(syn::Error::new( old.span(), - "option `recovery_fn` provided twice", + "option 
`cycle_initial` provided twice", )); } } else { return Err(syn::Error::new( ident.span(), - "`recovery_fn` option not allowed here", + "`cycle_initial` option not allowed here", )); } } else if ident == "data" { diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index 389d35fb8..178ed162e 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -41,7 +41,9 @@ impl crate::options::AllowedOptions for TrackedFn { const DB: bool = false; - const RECOVERY_FN: bool = true; + const CYCLE_FN: bool = true; + + const CYCLE_INITIAL: bool = true; const LRU: bool = true; @@ -72,9 +74,20 @@ impl Macro { let input_ids = self.input_ids(&item); let input_tys = self.input_tys(&item)?; let output_ty = self.output_ty(&db_lt, &item)?; - let (cycle_recovery_fn, cycle_recovery_strategy) = self.cycle_recovery(); + let (cycle_recovery_fn, cycle_recovery_initial, cycle_recovery_strategy) = + self.cycle_recovery()?; let is_specifiable = self.args.specify.is_some(); - let no_eq = self.args.no_eq.is_some(); + let no_eq = if let Some(token) = &self.args.no_eq { + if self.args.cycle_fn.is_some() { + return Err(syn::Error::new_spanned( + token, + "the `no_eq` option cannot be used with `cycle_fn`", + )); + } + true + } else { + false + }; let mut inner_fn = item.clone(); inner_fn.vis = syn::Visibility::Inherited; @@ -146,6 +159,7 @@ impl Macro { output_ty: #output_ty, inner_fn: { #inner_fn }, cycle_recovery_fn: #cycle_recovery_fn, + cycle_recovery_initial: #cycle_recovery_initial, cycle_recovery_strategy: #cycle_recovery_strategy, is_specifiable: #is_specifiable, no_eq: #no_eq, @@ -181,14 +195,28 @@ impl Macro { Ok(ValidFn { db_ident, db_path }) } - fn cycle_recovery(&self) -> (TokenStream, TokenStream) { - if let Some(recovery_fn) = &self.args.recovery_fn { - (quote!((#recovery_fn)), quote!(Fallback)) - } else { - ( + fn cycle_recovery(&self) -> syn::Result<(TokenStream, TokenStream, TokenStream)> { 
+ // TODO should we ask the user to specify a struct that impls a trait with two methods, + // rather than asking for two methods separately? + match (&self.args.cycle_fn, &self.args.cycle_initial) { + (Some(cycle_fn), Some(cycle_initial)) => Ok(( + quote!((#cycle_fn)), + quote!((#cycle_initial)), + quote!(Fixpoint), + )), + (None, None) => Ok(( quote!((salsa::plumbing::unexpected_cycle_recovery!)), + quote!((salsa::plumbing::unexpected_cycle_initial!)), quote!(Panic), - ) + )), + (Some(_), None) => Err(syn::Error::new_spanned( + self.args.cycle_fn.as_ref().unwrap(), + "must provide `cycle_initial` along with `cycle_fn`", + )), + (None, Some(_)) => Err(syn::Error::new_spanned( + self.args.cycle_initial.as_ref().unwrap(), + "must provide `cycle_fn` along with `cycle_initial`", + )), } } diff --git a/components/salsa-macros/src/tracked_struct.rs b/components/salsa-macros/src/tracked_struct.rs index bddba78bc..be8ccdf5f 100644 --- a/components/salsa-macros/src/tracked_struct.rs +++ b/components/salsa-macros/src/tracked_struct.rs @@ -47,7 +47,9 @@ impl crate::options::AllowedOptions for TrackedStruct { const DB: bool = false; - const RECOVERY_FN: bool = false; + const CYCLE_FN: bool = false; + + const CYCLE_INITIAL: bool = false; const LRU: bool = false; diff --git a/src/accumulator.rs b/src/accumulator.rs index e6f7e640d..a470c26a6 100644 --- a/src/accumulator.rs +++ b/src/accumulator.rs @@ -12,7 +12,8 @@ use accumulated::AnyAccumulated; use crate::{ cycle::CycleRecoveryStrategy, - ingredient::{fmt_index, Ingredient, Jar, MaybeChangedAfter}, + function::VerifyResult, + ingredient::{fmt_index, Ingredient, Jar}, plumbing::IngredientIndices, zalsa::{IngredientIndex, Zalsa}, zalsa_local::QueryOrigin, @@ -105,10 +106,18 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _input: Id, _revision: Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { panic!("nothing should ever depend on an accumulator directly") } + fn is_provisional_cycle_head<'db>(&'db self, 
_db: &'db dyn Database, _input: Id) -> bool { + false + } + + fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { + true + } + fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { CycleRecoveryStrategy::Panic } @@ -130,6 +139,7 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _executor: DatabaseKeyIndex, _stale_output_key: crate::Id, + _provisional: bool, ) { } diff --git a/src/active_query.rs b/src/active_query.rs index fe4f7a351..9156d2e3f 100644 --- a/src/active_query.rs +++ b/src/active_query.rs @@ -7,11 +7,12 @@ use crate::tracked_struct::{DisambiguatorMap, IdentityHash, IdentityMap}; use crate::zalsa_local::QueryEdge; use crate::{ accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}, + cycle::CycleHeads, durability::Durability, hash::FxIndexSet, key::{DatabaseKeyIndex, InputDependencyIndex}, tracked_struct::Disambiguator, - Cycle, Revision, + Revision, }; #[derive(Debug)] @@ -37,9 +38,6 @@ pub(crate) struct ActiveQuery { /// True if there was an untracked read. untracked_read: bool, - /// Stores the entire cycle, if one is found and this query is part of it. - pub(crate) cycle: Option, - /// When new tracked structs are created, their data is hashed, and the resulting /// hash is added to this map. If it is not present, then the disambiguator is 0. /// Otherwise it is 1 more than the current value (which is incremented). @@ -60,6 +58,9 @@ pub(crate) struct ActiveQuery { /// [`InputAccumulatedValues::Empty`] if any input read during the query's execution /// has any accumulated values. pub(super) accumulated_inputs: InputAccumulatedValues, + + /// Provisional cycle results that this query depends on. 
+ pub(crate) cycle_heads: CycleHeads, } impl ActiveQuery { @@ -70,11 +71,11 @@ impl ActiveQuery { changed_at: Revision::start(), input_outputs: FxIndexSet::default(), untracked_read: false, - cycle: None, disambiguator_map: Default::default(), tracked_struct_ids: Default::default(), accumulated: Default::default(), accumulated_inputs: Default::default(), + cycle_heads: Default::default(), } } @@ -84,11 +85,13 @@ impl ActiveQuery { durability: Durability, revision: Revision, accumulated: InputAccumulatedValues, + cycle_heads: &CycleHeads, ) { self.input_outputs.insert(QueryEdge::Input(input)); self.durability = self.durability.min(durability); self.changed_at = self.changed_at.max(revision); self.accumulated_inputs |= accumulated; + self.cycle_heads.extend(cycle_heads); } pub(super) fn add_untracked_read(&mut self, changed_at: Revision) { @@ -132,36 +135,10 @@ impl ActiveQuery { tracked_struct_ids: self.tracked_struct_ids, accumulated_inputs: AtomicInputAccumulatedValues::new(self.accumulated_inputs), accumulated, + cycle_heads: self.cycle_heads, } } - /// Adds any dependencies from `other` into `self`. - /// Used during cycle recovery, see [`Runtime::unblock_cycle_and_maybe_throw`]. - pub(super) fn add_from(&mut self, other: &ActiveQuery) { - self.changed_at = self.changed_at.max(other.changed_at); - self.durability = self.durability.min(other.durability); - self.untracked_read |= other.untracked_read; - self.input_outputs - .extend(other.input_outputs.iter().copied()); - } - - /// Removes the participants in `cycle` from my dependencies. - /// Used during cycle recovery, see [`Runtime::unblock_cycle_and_maybe_throw`]. - pub(super) fn remove_cycle_participants(&mut self, cycle: &Cycle) { - for p in cycle.participant_keys() { - let p: InputDependencyIndex = p.into(); - self.input_outputs.shift_remove(&QueryEdge::Input(p)); - } - } - - /// Copy the changed-at, durability, and dependencies from `cycle_query`. 
- /// Used during cycle recovery, see [`Runtime::unblock_cycle_and_maybe_throw`]. - pub(crate) fn take_inputs_from(&mut self, cycle_query: &ActiveQuery) { - self.changed_at = cycle_query.changed_at; - self.durability = cycle_query.durability; - self.input_outputs.clone_from(&cycle_query.input_outputs); - } - pub(super) fn disambiguate(&mut self, key: IdentityHash) -> Disambiguator { self.disambiguator_map.disambiguate(key) } diff --git a/src/cycle.rs b/src/cycle.rs index 8483a2857..b5a0554b5 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -1,109 +1,178 @@ -use crate::{key::DatabaseKeyIndex, Database}; -use std::{panic::AssertUnwindSafe, sync::Arc}; +//! Cycle handling +//! +//! Salsa's default cycle handling is quite simple: if we encounter a cycle (that is, if we attempt +//! to execute a query that is already on the active query stack), we panic. +//! +//! By setting `cycle_fn` and `cycle_initial` arguments to `salsa::tracked`, queries can opt-in to +//! fixed-point iteration instead. +//! +//! We call the query which triggers the cycle (that is, the query that is already on the stack +//! when it is called again) the "cycle head". The cycle head is responsible for managing iteration +//! of the cycle. When a cycle is encountered, if the cycle head has `cycle_fn` and `cycle_initial` +//! set, it will call the `cycle_initial` function to generate an "empty" or "initial" value for +//! fixed-point iteration, which will be returned to its caller. Then each query in the cycle will +//! compute a value normally, but every computed value will track the head(s) of the cycles it is +//! part of. Every query's "cycle heads" are the union of all the cycle heads of all the queries it +//! depends on. A memoized query result with cycle heads is called a "provisional value". +//! +//! For example, if `qa` calls `qb`, and `qb` calls `qc`, and `qc` calls `qa`, then `qa` will call +//! 
its `cycle_initial` function to get an initial value, and return that as its result to `qc`, +//! marked with `qa` as cycle head. `qc` will compute its own provisional result based on that, and +//! return to `qb` a result also marked with `qa` as cycle head. `qb` will similarly compute and +//! return a provisional value back to `qa`. +//! +//! When a query observes that it has just computed a result which contains itself as a cycle head, +//! it recognizes that it is responsible for resolving this cycle and calls its `cycle_fn` to +//! decide how to do so. The `cycle_fn` function is passed the provisional value just computed for +//! that query and the count of iterations so far, and must return either +//! `CycleRecoveryAction::Iterate` (which signals that the cycle head should re-iterate the cycle), +//! or `CycleRecoveryAction::Fallback` (which signals that the cycle head should replace its +//! computed value with the given fallback value). +//! +//! If the cycle head ever observes that the provisional value it just recomputed is the same as +//! the provisional value from the previous iteration, the cycle has converged. The cycle head will +//! mark that value as final (by removing itself as cycle head) and return it. +//! +//! Other queries in the cycle will still have provisional values recorded, but those values should +//! now also be considered final! We don't eagerly walk the entire cycle to mark them final. +//! Instead, we wait until the next time that provisional value is read, and then we check if all +//! of its cycle heads have a final result, in which case it, too, can be marked final. (This is +//! implemented in `shallow_verify_memo` and `validate_provisional`.) +//! +//! If the `cycle_fn` returns a fallback value, the cycle head will replace its provisional value +//! with that fallback, and then iterate the cycle one more time. A fallback value is expected to +//! result in a stable, converged cycle. 
If it does not (that is, if the result of another +//! iteration of the cycle is not the same as the fallback value), we'll panic. +//! +//! In nested cycle cases, the inner cycle head will iterate until its own cycle is resolved, but +//! the "final" value it then returns will still be provisional on the outer cycle head. The outer +//! cycle head may then iterate, which may result in a new set of iterations on the inner cycle, +//! for each iteration of the outer cycle. -/// Captures the participants of a cycle that occurred when executing a query. -/// -/// This type is meant to be used to help give meaningful error messages to the -/// user or to help salsa developers figure out why their program is resulting -/// in a computation cycle. -/// -/// It is used in a few ways: -/// -/// * During [cycle recovery](https://https://salsa-rs.github.io/salsa/cycles/fallback.html), -/// where it is given to the fallback function. -/// * As the panic value when an unexpected cycle (i.e., a cycle where one or more participants -/// lacks cycle recovery information) occurs. +use crate::key::DatabaseKeyIndex; +use rustc_hash::FxHashSet; + +/// The maximum number of times we'll fixpoint-iterate before panicking. /// -/// You can read more about cycle handling in -/// the [salsa book](https://https://salsa-rs.github.io/salsa/cycles.html). -#[derive(Clone, PartialEq, Eq, Hash)] -pub struct Cycle { - participants: CycleParticipants, +/// Should only be relevant in case of a badly configured cycle recovery. +pub const MAX_ITERATIONS: u32 = 200; + +/// Return value from a cycle recovery function. +#[derive(Debug)] +pub enum CycleRecoveryAction { + /// Iterate the cycle again to look for a fixpoint. + Iterate, + + /// Cut off iteration and use the given result value for this query. + Fallback(T), +} + +/// Cycle recovery strategy: Is this query capable of recovering from +/// a cycle that results from executing the function? If so, how? 
+#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum CycleRecoveryStrategy { + /// Cannot recover from cycles: panic. + /// + /// This is the default. + Panic, + + /// Recovers from cycles by fixpoint iterating and/or falling + /// back to a sentinel value. + /// + /// This choice is computed by the query's `cycle_recovery` + /// function and initial value. + Fixpoint, } -// We want `Cycle`` to be thin -pub(crate) type CycleParticipants = Arc>; +/// A "cycle head" is the query at which we encounter a cycle; that is, if A -> B -> C -> A, then A +/// would be the cycle head. It returns an "initial value" when the cycle is encountered (if +/// fixpoint iteration is enabled for that query), and then is responsible for re-iterating the +/// cycle until it converges. Any provisional value generated by any query in the cycle will track +/// the cycle head(s) (can be plural in case of nested cycles) representing the cycles it is part +/// of. This struct tracks these cycle heads. +#[derive(Clone, Debug, Default)] +pub(crate) struct CycleHeads(Option>>); -impl Cycle { - pub(crate) fn new(participants: CycleParticipants) -> Self { - Self { participants } +impl CycleHeads { + pub(crate) fn is_empty(&self) -> bool { + // We ensure in `remove` and `extend` that we never have an empty hashset, we always use + // None to signify empty. + self.0.is_none() } - /// True if two `Cycle` values represent the same cycle. - pub(crate) fn is(&self, cycle: &Cycle) -> bool { - Arc::ptr_eq(&self.participants, &cycle.participants) + pub(crate) fn contains(&self, value: &DatabaseKeyIndex) -> bool { + self.0.as_ref().is_some_and(|heads| heads.contains(value)) } - pub(crate) fn throw(self) -> ! 
{ - tracing::debug!("throwing cycle {:?}", self); - std::panic::resume_unwind(Box::new(self)) + pub(crate) fn remove(&mut self, value: &DatabaseKeyIndex) -> bool { + if let Some(cycle_heads) = self.0.as_mut() { + let found = cycle_heads.remove(value); + if found && cycle_heads.is_empty() { + self.0.take(); + } + found + } else { + false + } } +} - pub(crate) fn catch(execute: impl FnOnce() -> T) -> Result { - match std::panic::catch_unwind(AssertUnwindSafe(execute)) { - Ok(v) => Ok(v), - Err(err) => match err.downcast::() { - Ok(cycle) => Err(*cycle), - Err(other) => std::panic::resume_unwind(other), - }, +impl std::iter::Extend for CycleHeads { + fn extend>(&mut self, iter: T) { + let mut iter = iter.into_iter(); + if let Some(first) = iter.next() { + let heads = self.0.get_or_insert(Box::new(FxHashSet::default())); + heads.insert(first); + heads.extend(iter) } } +} + +impl std::iter::IntoIterator for CycleHeads { + type Item = DatabaseKeyIndex; + type IntoIter = std::collections::hash_set::IntoIter; - /// Iterate over the [`DatabaseKeyIndex`] for each query participating - /// in the cycle. The start point of this iteration within the cycle - /// is arbitrary but deterministic, but the ordering is otherwise determined - /// by the execution. - pub fn participant_keys(&self) -> impl Iterator + '_ { - self.participants.iter().copied() + fn into_iter(self) -> Self::IntoIter { + self.0.map(|heads| *heads).unwrap_or_default().into_iter() } +} + +// This type can be removed once MSRV is 1.83+ and we have Default for hashset iterators. +pub(crate) struct CycleHeadsIter<'a>( + Option>, +); + +impl Iterator for CycleHeadsIter<'_> { + type Item = DatabaseKeyIndex; - /// Returns a vector with the debug information for - /// all the participants in the cycle. 
- pub fn all_participants(&self, _db: &dyn Database) -> Vec { - self.participant_keys().collect() + fn next(&mut self) -> Option { + self.0.as_mut()?.next().copied() } - /// Returns a vector with the debug information for - /// those participants in the cycle that lacked recovery - /// information. - pub fn unexpected_participants(&self, db: &dyn Database) -> Vec { - self.participant_keys() - .filter(|&d| d.cycle_recovery_strategy(db) == CycleRecoveryStrategy::Panic) - .collect() + fn last(self) -> Option { + self.0?.last().copied() } } -impl std::fmt::Debug for Cycle { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - crate::attach::with_attached_database(|db| { - f.debug_struct("UnexpectedCycle") - .field("all_participants", &self.all_participants(db)) - .field("unexpected_participants", &self.unexpected_participants(db)) - .finish() - }) - .unwrap_or_else(|| { - f.debug_struct("Cycle") - .field("participants", &self.participants) - .finish() - }) +impl std::iter::FusedIterator for CycleHeadsIter<'_> {} + +impl<'a> std::iter::IntoIterator for &'a CycleHeads { + type Item = DatabaseKeyIndex; + type IntoIter = CycleHeadsIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + CycleHeadsIter(self.0.as_ref().map(|heads| heads.iter())) } } -/// Cycle recovery strategy: Is this query capable of recovering from -/// a cycle that results from executing the function? If so, how? -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum CycleRecoveryStrategy { - /// Cannot recover from cycles: panic. - /// - /// This is the default. - /// - /// In the case of a failure due to a cycle, the panic - /// value will be the `Cycle`. - Panic, - - /// Recovers from cycles by storing a sentinel value. - /// - /// This value is computed by the query's `recovery_fn` - /// function. 
- Fallback, +impl From> for CycleHeads { + fn from(value: FxHashSet) -> Self { + Self(if value.is_empty() { + None + } else { + Some(Box::new(value)) + }) + } } + +pub(crate) static EMPTY_CYCLE_HEADS: CycleHeads = CycleHeads(None); diff --git a/src/function.rs b/src/function.rs index b08045a8f..216f0fac9 100644 --- a/src/function.rs +++ b/src/function.rs @@ -2,22 +2,25 @@ use std::{any::Any, fmt, ptr::NonNull}; use crate::{ accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}, - cycle::CycleRecoveryStrategy, - ingredient::{fmt_index, MaybeChangedAfter}, + cycle::{CycleRecoveryAction, CycleRecoveryStrategy}, + ingredient::fmt_index, key::DatabaseKeyIndex, plumbing::MemoIngredientMap, salsa_struct::SalsaStructInDb, + table::sync::ClaimResult, table::Table, views::DatabaseDownCaster, zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}, zalsa_local::QueryOrigin, - Cycle, Database, Id, Revision, + Database, Id, Revision, }; use self::delete::DeletedEntries; use super::ingredient::Ingredient; +pub(crate) use maybe_changed_after::VerifyResult; + mod accumulated; mod backdate; mod delete; @@ -51,13 +54,12 @@ pub trait Configuration: Any { /// (and, if so, how). const CYCLE_STRATEGY: CycleRecoveryStrategy; - /// Invokes after a new result `new_value` has been computed for which an older memoized - /// value existed `old_value`. Returns true if the new value is equal to the older one - /// and hence should be "backdated" (i.e., marked as having last changed in an older revision, - /// even though it was recomputed). + /// Invoked after a new result `new_value` has been computed for which an older memoized value + /// existed `old_value`, or in fixpoint iteration. Returns true if the new value is equal to + /// the older one. + /// - /// This invokes user's code in form of the `Eq` impl. - fn should_backdate_value(old_value: &Self::Output<'_>, new_value: &Self::Output<'_>) -> bool; + /// This invokes user code in the form of the `Eq` impl. 
+ fn values_equal(old_value: &Self::Output<'_>, new_value: &Self::Output<'_>) -> bool; /// Convert from the id used internally to the value that execute is expecting. /// This is a no-op if the input to the function is a salsa struct. @@ -69,15 +71,18 @@ pub trait Configuration: Any { /// This invokes the function the user wrote. fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; - /// If the cycle strategy is `Fallback`, then invoked when `key` is a participant - /// in a cycle to find out what value it should have. - /// - /// This invokes the recovery function given by the user. + /// Get the cycle recovery initial value. + fn cycle_initial<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; + + /// Decide whether to iterate a cycle again or fallback. `value` is the provisional return + /// value from the latest iteration of this cycle. `count` is the number of cycle iterations + /// we've already completed. fn recover_from_cycle<'db>( db: &'db Self::DbView, - cycle: &Cycle, + value: &Self::Output<'db>, + count: u32, input: Self::Input<'db>, - ) -> Self::Output<'db>; + ) -> CycleRecoveryAction>; } /// Function ingredients are the "workhorse" of salsa. @@ -130,9 +135,9 @@ pub struct IngredientImpl { } /// True if `old_value == new_value`. Invoked by the generated -/// code for `should_backdate_value` so as to give a better +/// code for `values_equal` so as to give a better /// error message. 
-pub fn should_backdate_value(old_value: &V, new_value: &V) -> bool { +pub fn values_equal(old_value: &V, new_value: &V) -> bool { old_value == new_value } @@ -229,12 +234,37 @@ where db: &dyn Database, input: Id, revision: Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { // SAFETY: The `db` belongs to the ingredient as per caller invariant let db = unsafe { self.view_caster.downcast_unchecked(db) }; self.maybe_changed_after(db, input, revision) } + /// True if the input `input` contains a memo that cites itself as a cycle head. + /// This indicates an intermediate value for a cycle that has not yet reached a fixed point. + fn is_provisional_cycle_head<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool { + self.get_memo_from_table_for( + db.zalsa(), + input, + self.memo_ingredient_index(db.zalsa(), input), + ) + .is_some_and(|memo| memo.cycle_heads().contains(&self.database_key_index(input))) + } + + /// Attempts to claim `key_index`, returning `false` if a cycle occurs. + fn wait_for(&self, db: &dyn Database, key_index: Id) -> bool { + let zalsa = db.zalsa(); + match zalsa.sync_table_for(key_index).claim( + db, + zalsa, + self.database_key_index(key_index), + self.memo_ingredient_index(zalsa, key_index), + ) { + ClaimResult::Retry | ClaimResult::Claimed(_) => true, + ClaimResult::Cycle => false, + } + } + fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { C::CYCLE_STRATEGY } @@ -257,6 +287,7 @@ where _db: &dyn Database, _executor: DatabaseKeyIndex, _stale_output_key: crate::Id, + _provisional: bool, ) { // This function is invoked when a query Q specifies the value for `stale_output_key` in rev 1, // but not in rev 2. We don't do anything in this case, we just leave the (now stale) memo. diff --git a/src/function/backdate.rs b/src/function/backdate.rs index 57733d51a..9b6ee1906 100644 --- a/src/function/backdate.rs +++ b/src/function/backdate.rs @@ -21,7 +21,7 @@ where // consumers must be aware of. Becoming *more* durable // is not. 
See the test `durable_to_less_durable`. if revisions.durability >= old_memo.revisions.durability - && C::should_backdate_value(old_value, value) + && C::values_equal(old_value, value) { tracing::debug!( "value is equal, back-dating to {:?}", diff --git a/src/function/diff_outputs.rs b/src/function/diff_outputs.rs index 224b0434e..03b243529 100644 --- a/src/function/diff_outputs.rs +++ b/src/function/diff_outputs.rs @@ -8,11 +8,14 @@ impl IngredientImpl where C: Configuration, { - /// Compute the old and new outputs and invoke the `clear_stale_output` callback + /// Compute the old and new outputs and invoke `remove_stale_output` /// for each output that was generated before but is not generated now. /// /// This function takes a `&mut` reference to `revisions` to remove outputs /// that no longer exist in this revision from [`QueryRevisions::tracked_struct_ids`]. + /// + /// If `provisional` is true, the new outputs are from a cycle-provisional result. In + /// that case, we won't panic if we see outputs from the current revision become stale. 
pub(super) fn diff_outputs( &self, zalsa: &Zalsa, @@ -20,6 +23,7 @@ where key: DatabaseKeyIndex, old_memo: &Memo>, revisions: &mut QueryRevisions, + provisional: bool, ) { // Iterate over the outputs of the `old_memo` and put them into a hashset let mut old_outputs: FxHashSet<_> = old_memo.revisions.origin.outputs().collect(); @@ -39,7 +43,7 @@ where } for old_output in old_outputs { - Self::report_stale_output(zalsa, db, key, old_output); + Self::report_stale_output(zalsa, db, key, old_output, provisional); } } @@ -48,6 +52,7 @@ where db: &C::DbView, key: DatabaseKeyIndex, output: OutputDependencyIndex, + provisional: bool, ) { db.salsa_event(&|| { Event::new(EventKind::WillDiscardStaleOutput { @@ -55,7 +60,6 @@ where output_key: output, }) }); - - output.remove_stale_output(zalsa, db.as_dyn_database(), key); + output.remove_stale_output(zalsa, db.as_dyn_database(), key, provisional); } } diff --git a/src/function/execute.rs b/src/function/execute.rs index c24c9ea01..87844d299 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -1,5 +1,8 @@ use crate::{ - zalsa::ZalsaDatabase, zalsa_local::ActiveQueryGuard, Cycle, Database, Event, EventKind, + cycle::{CycleRecoveryStrategy, MAX_ITERATIONS}, + zalsa::ZalsaDatabase, + zalsa_local::ActiveQueryGuard, + Database, Event, EventKind, }; use super::{memo::Memo, Configuration, IngredientImpl}; @@ -20,12 +23,13 @@ where pub(super) fn execute<'db>( &'db self, db: &'db C::DbView, - active_query: ActiveQueryGuard<'_>, + mut active_query: ActiveQueryGuard<'db>, opt_old_memo: Option<&Memo>>, ) -> &'db Memo> { - let zalsa = db.zalsa(); + let (zalsa, zalsa_local) = db.zalsas(); let revision_now = zalsa.current_revision(); let database_key_index = active_query.database_key_index; + let id = database_key_index.key_index; tracing::info!("{:?}: executing query", database_key_index); @@ -35,59 +39,148 @@ where }) }); - // If we already executed this query once, then use the tracked-struct ids from the - // previous 
execution as the starting point for the new one. - if let Some(old_memo) = opt_old_memo { - active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids); - } + let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); - // Query was not previously executed, or value is potentially - // stale, or value is absent. Let's execute! - let database_key_index = active_query.database_key_index; - let id = database_key_index.key_index; - let value = match Cycle::catch(|| C::execute(db, C::id_to_input(db, id))) { - Ok(v) => v, - Err(cycle) => { + let mut iteration_count: u32 = 0; + let mut fell_back = false; + + // Our provisional value from the previous iteration, when doing fixpoint iteration. + // Initially it's set to None, because the initial provisional value is created lazily, + // only when a cycle is actually encountered. + let mut opt_last_provisional: Option<&Memo<::Output<'db>>> = None; + + loop { + // If we already executed this query once, then use the tracked-struct ids from the + // previous execution as the starting point for the new one. + if let Some(old_memo) = opt_old_memo { + active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids); + } + + // Query was not previously executed, or value is potentially + // stale, or value is absent. Let's execute! + let mut new_value = C::execute(db, C::id_to_input(db, id)); + let mut revisions = active_query.pop(); + + // Did the new result we got depend on our own provisional value, in a cycle? + if C::CYCLE_STRATEGY == CycleRecoveryStrategy::Fixpoint + && revisions.cycle_heads.contains(&database_key_index) + { + let opt_owned_last_provisional; + let last_provisional_value = if let Some(last_provisional) = opt_last_provisional { + // We have a last provisional value from our previous time around the loop. 
+ last_provisional + .value + .as_ref() + .expect("provisional value should not be evicted by LRU") + } else { + // This is our first time around the loop; a provisional value must have been + // inserted into the memo table when the cycle was hit, so let's pull our + // initial provisional value from there. + opt_owned_last_provisional = + self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); + debug_assert!(opt_owned_last_provisional + .as_ref() + .unwrap() + .may_be_provisional()); + opt_owned_last_provisional + .expect( + "{database_key_index:#?} is a cycle head, \ + but no provisional memo found", + ) + .value + .as_ref() + .expect("provisional value should not be evicted by LRU") + }; tracing::debug!( - "{database_key_index:?}: caught cycle {cycle:?}, have strategy {:?}", - C::CYCLE_STRATEGY + "{database_key_index:?}: execute: \ + I am a cycle head, comparing last provisional value with new value" ); - match C::CYCLE_STRATEGY { - crate::cycle::CycleRecoveryStrategy::Panic => cycle.throw(), - crate::cycle::CycleRecoveryStrategy::Fallback => { - if let Some(c) = active_query.take_cycle() { - assert!(c.is(&cycle)); - C::recover_from_cycle(db, &cycle, C::id_to_input(db, id)) - } else { - // we are not a participant in this cycle - debug_assert!(!cycle - .participant_keys() - .any(|k| k == database_key_index)); - cycle.throw() + // If the new result is equal to the last provisional result, the cycle has + // converged and we are done. + if !C::values_equal(&new_value, last_provisional_value) { + if fell_back { + // We fell back to a value last iteration, but the fallback didn't result + // in convergence. We only have bad options here: continue iterating + // (ignoring the request to fall back), or forcibly use the fallback and + // leave the cycle in an inconsistent state (we'll be using a value for + // this query that it doesn't evaluate to, given its inputs). 
Maybe we'll + // have to go with the latter, but for now let's panic and see if real use + // cases need non-converging fallbacks. + panic!("{database_key_index:?}: execute: fallback did not converge"); + } + // We are in a cycle that hasn't converged; ask the user's + // cycle-recovery function what to do: + match C::recover_from_cycle( + db, + &new_value, + iteration_count, + C::id_to_input(db, id), + ) { + crate::CycleRecoveryAction::Iterate => { + tracing::debug!("{database_key_index:?}: execute: iterate again"); + } + crate::CycleRecoveryAction::Fallback(fallback_value) => { + tracing::debug!( + "{database_key_index:?}: execute: user cycle_fn says to fall back" + ); + new_value = fallback_value; + // We have to insert the fallback value for this query and then iterate + // one more time to fill in correct values for everything else in the + // cycle based on it; then we'll re-insert it as final value. + fell_back = true; } } + iteration_count = iteration_count + .checked_add(1) + .expect("fixpoint iteration should converge before u32::MAX iterations"); + if iteration_count > MAX_ITERATIONS { + panic!("{database_key_index:?}: execute: too many cycle iterations"); + } + opt_last_provisional = Some(self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + memo_ingredient_index, + )); + + active_query = zalsa_local.push_query(database_key_index); + + continue; } + tracing::debug!( + "{database_key_index:?}: execute: fixpoint iteration has a final value" + ); + revisions.cycle_heads.remove(&database_key_index); } - }; - let mut revisions = active_query.pop(); - - // If the new value is equal to the old one, then it didn't - // really change, even if some of its inputs have. So we can - // "backdate" its `changed_at` revision to be the same as the - // old value. 
- if let Some(old_memo) = opt_old_memo { - self.backdate_if_appropriate(old_memo, &mut revisions, &value); - self.diff_outputs(zalsa, db, database_key_index, old_memo, &mut revisions); - } - tracing::debug!("{database_key_index:?}: read_upgrade: result.revisions = {revisions:#?}"); + tracing::debug!("{database_key_index:?}: execute: result.revisions = {revisions:#?}"); - let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); - self.insert_memo( - zalsa, - id, - Memo::new(Some(value), revision_now, revisions), - memo_ingredient_index, - ) + if let Some(old_memo) = opt_old_memo { + // If the new value is equal to the old one, then it didn't + // really change, even if some of its inputs have. So we can + // "backdate" its `changed_at` revision to be the same as the + // old value. + self.backdate_if_appropriate(old_memo, &mut revisions, &new_value); + + // Diff the new outputs with the old, to discard any no-longer-emitted + // outputs and update the tracked struct IDs for seeding the next revision. 
+ let provisional = !revisions.cycle_heads.is_empty(); + self.diff_outputs( + zalsa, + db, + database_key_index, + old_memo, + &mut revisions, + provisional, + ); + } + + return self.insert_memo( + zalsa, + id, + Memo::new(Some(new_value), revision_now, revisions), + memo_ingredient_index, + ); + } } } diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 9f687ab17..793301438 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -1,10 +1,12 @@ -use super::{memo::Memo, Configuration, IngredientImpl}; +use super::{memo::Memo, Configuration, IngredientImpl, VerifyResult}; use crate::zalsa::MemoIngredientIndex; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, runtime::StampedValue, + table::sync::ClaimResult, zalsa::{Zalsa, ZalsaDatabase}, - Id, + zalsa_local::QueryRevisions, + AsDynDatabase as _, Id, }; impl IngredientImpl @@ -35,6 +37,7 @@ where Some(_) => InputAccumulatedValues::Any, None => memo.revisions.accumulated_inputs.load(), }, + memo.cycle_heads(), ); value @@ -53,7 +56,16 @@ where .fetch_hot(zalsa, db, id, memo_ingredient_index) .or_else(|| self.fetch_cold(zalsa, db, id, memo_ingredient_index)) { - return memo; + // If we get back a provisional cycle memo, and it's provisional on any cycle heads + // that are claimed by a different thread, we can't propagate the provisional memo + // any further (it could escape outside the cycle); we need to block on the other + // thread completing fixpoint iteration of the cycle, and then we can re-query for + // our no-longer-provisional memo. 
+ if !(memo.may_be_provisional() + && memo.provisional_retry(db.as_dyn_database(), self.database_key_index(id))) + { + return memo; + } } } } @@ -69,7 +81,7 @@ where let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(memo) = memo_guard { if memo.value.is_some() - && self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo) + && self.shallow_verify_memo(db, zalsa, self.database_key_index(id), memo, false) { // Unsafety invariant: memo is present in memo_map and we have verified that it is // still valid for the current revision. @@ -89,10 +101,58 @@ where let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. - let _claim_guard = - zalsa - .sync_table_for(id) - .claim(db, zalsa, database_key_index, memo_ingredient_index)?; + let _claim_guard = match zalsa.sync_table_for(id).claim( + db, + zalsa, + database_key_index, + memo_ingredient_index, + ) { + ClaimResult::Retry => return None, + ClaimResult::Cycle => { + // check if there's a provisional value for this query + let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); + if let Some(memo) = &memo_guard { + if memo.value.is_some() + && memo.revisions.cycle_heads.contains(&database_key_index) + && self.shallow_verify_memo(db, zalsa, database_key_index, memo, true) + { + // Unsafety invariant: memo is present in memo_map. 
+ unsafe { + return Some(self.extend_memo_lifetime(memo)); + } + } + } + // no provisional value; create/insert/return initial provisional value + return self + .initial_value(db, database_key_index.key_index) + .map(|initial_value| { + tracing::debug!( + "hit cycle at {database_key_index:#?}, \ + inserting and returning fixpoint initial value" + ); + self.insert_memo( + zalsa, + id, + Memo::new( + Some(initial_value), + zalsa.current_revision(), + QueryRevisions::fixpoint_initial( + database_key_index, + zalsa.current_revision(), + ), + ), + memo_ingredient_index, + ) + }) + .or_else(|| { + panic!( + "dependency graph cycle querying {database_key_index:#?}; \ + set cycle_fn/cycle_initial to fixpoint iterate" + ) + }); + } + ClaimResult::Claimed(guard) => guard, + }; // Push the query on the stack. let active_query = db.zalsa_local().push_query(database_key_index); @@ -100,14 +160,21 @@ where // Now that we've claimed the item, check again to see if there's a "hot" value. let opt_old_memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(old_memo) = opt_old_memo { - if old_memo.value.is_some() && self.deep_verify_memo(db, zalsa, old_memo, &active_query) - { - // Unsafety invariant: memo is present in memo_map and we have verified that it is - // still valid for the current revision. - return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; + if old_memo.value.is_some() { + if let VerifyResult::Unchanged(_, cycle_heads) = + self.deep_verify_memo(db, zalsa, old_memo, &active_query) + { + if cycle_heads.is_empty() { + // Unsafety invariant: memo is present in memo_map and we have verified that it is + // still valid for the current revision. 
+ return unsafe { Some(self.extend_memo_lifetime(old_memo)) }; + } + } } } - Some(self.execute(db, active_query, opt_old_memo)) + let memo = self.execute(db, active_query, opt_old_memo); + + Some(memo) } } diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 762455ad4..a46343f4e 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -1,14 +1,46 @@ use crate::{ accumulator::accumulated_map::InputAccumulatedValues, - ingredient::MaybeChangedAfter, + cycle::CycleRecoveryStrategy, key::DatabaseKeyIndex, + table::sync::ClaimResult, zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase}, zalsa_local::{ActiveQueryGuard, QueryEdge, QueryOrigin}, AsDynDatabase as _, Id, Revision, }; +use rustc_hash::FxHashSet; +use std::sync::atomic::Ordering; use super::{memo::Memo, Configuration, IngredientImpl}; +/// Result of memo validation. +pub enum VerifyResult { + /// Memo has changed and needs to be recomputed. + Changed, + + /// Memo remains valid. + /// + /// The first inner value tracks whether the memo or any of its dependencies have an + /// accumulated value. + /// + /// Database keys in the hashset represent cycle heads encountered in validation; don't mark + /// memos verified until we've iterated the full cycle to ensure no inputs changed. 
+ Unchanged(InputAccumulatedValues, FxHashSet), +} + +impl VerifyResult { + pub(crate) fn changed_if(changed: bool) -> Self { + if changed { + Self::Changed + } else { + Self::unchanged() + } + } + + pub(crate) fn unchanged() -> Self { + Self::Unchanged(InputAccumulatedValues::Empty, FxHashSet::default()) + } +} + impl IngredientImpl where C: Configuration, @@ -18,7 +50,7 @@ where db: &'db C::DbView, id: Id, revision: Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { let zalsa = db.zalsa(); let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); zalsa.unwind_if_revision_cancelled(db); @@ -31,11 +63,14 @@ where // Check if we have a verified version: this is the hot path. let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); if let Some(memo) = memo_guard { - if self.shallow_verify_memo(db, zalsa, database_key_index, memo) { + if self.shallow_verify_memo(db, zalsa, database_key_index, memo, false) { return if memo.revisions.changed_at > revision { - MaybeChangedAfter::Yes + VerifyResult::Changed } else { - MaybeChangedAfter::No(memo.revisions.accumulated_inputs.load()) + VerifyResult::Unchanged( + memo.revisions.accumulated_inputs.load(), + FxHashSet::default(), + ) }; } if let Some(mcs) = @@ -47,7 +82,7 @@ where } } else { // No memo? Assume has changed. 
- return MaybeChangedAfter::Yes; + return VerifyResult::Changed; } } } @@ -59,21 +94,34 @@ where key_index: Id, revision: Revision, memo_ingredient_index: MemoIngredientIndex, - ) -> Option { + ) -> Option { let database_key_index = self.database_key_index(key_index); - let _claim_guard = zalsa.sync_table_for(key_index).claim( + let _claim_guard = match zalsa.sync_table_for(key_index).claim( db, zalsa, database_key_index, memo_ingredient_index, - )?; - let active_query = db.zalsa_local().push_query(database_key_index); - + ) { + ClaimResult::Retry => return None, + ClaimResult::Cycle => match C::CYCLE_STRATEGY { + CycleRecoveryStrategy::Panic => panic!( + "dependency graph cycle validating {database_key_index:#?}; \ + set cycle_fn/cycle_initial to fixpoint iterate" + ), + CycleRecoveryStrategy::Fixpoint => { + return Some(VerifyResult::Unchanged( + InputAccumulatedValues::Empty, + FxHashSet::from_iter([database_key_index]), + )); + } + }, + ClaimResult::Claimed(guard) => guard, + }; // Load the current memo, if any. let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index, memo_ingredient_index) else { - return Some(MaybeChangedAfter::Yes); + return Some(VerifyResult::Changed); }; tracing::debug!( @@ -83,11 +131,14 @@ where ); // Check if the inputs are still valid. We can just compare `changed_at`. 
- if self.deep_verify_memo(db, zalsa, old_memo, &active_query) { + let active_query = db.zalsa_local().push_query(database_key_index); + if let VerifyResult::Unchanged(_, cycle_heads) = + self.deep_verify_memo(db, zalsa, old_memo, &active_query) + { return Some(if old_memo.revisions.changed_at > revision { - MaybeChangedAfter::Yes + VerifyResult::Changed } else { - MaybeChangedAfter::No(old_memo.revisions.accumulated_inputs.load()) + VerifyResult::Unchanged(old_memo.revisions.accumulated_inputs.load(), cycle_heads) }); } @@ -100,21 +151,34 @@ where let changed_at = memo.revisions.changed_at; return Some(if changed_at > revision { - MaybeChangedAfter::Yes + VerifyResult::Changed } else { - MaybeChangedAfter::No(match &memo.revisions.accumulated { - Some(_) => InputAccumulatedValues::Any, - None => memo.revisions.accumulated_inputs.load(), - }) + VerifyResult::Unchanged( + match &memo.revisions.accumulated { + Some(_) => InputAccumulatedValues::Any, + None => memo.revisions.accumulated_inputs.load(), + }, + FxHashSet::default(), + ) }); } // Otherwise, nothing for it: have to consider the value to have changed. - Some(MaybeChangedAfter::Yes) + Some(VerifyResult::Changed) } /// True if the memo's value and `changed_at` time is still valid in this revision. /// Does only a shallow O(1) check, doesn't walk the dependencies. + /// + /// In general, a provisional memo (from cycle iteration) does not verify. Since we don't + /// eagerly finalize all provisional memos in cycle iteration, we have to lazily check here + /// (via `validate_provisional`) whether a may-be-provisional memo should actually be verified + /// final, because its cycle heads are all now final. + /// + /// If `allow_provisional` is `true`, don't check provisionality and return whatever memo we + /// find that can be verified in this revision, whether provisional or not. 
This only occurs at + /// one call-site, in `fetch_cold` when we actually encounter a cycle, and want to check if + /// there is an existing provisional memo we can reuse. #[inline] pub(super) fn shallow_verify_memo( &self, @@ -122,14 +186,23 @@ where zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, memo: &Memo>, + allow_provisional: bool, ) -> bool { - let verified_at = memo.verified_at.load(); - let revision_now = zalsa.current_revision(); - tracing::debug!( "{database_key_index:?}: shallow_verify_memo(memo = {memo:#?})", memo = memo.tracing_debug() ); + if !allow_provisional && memo.may_be_provisional() { + tracing::debug!( + "{database_key_index:?}: validate_provisional(memo = {memo:#?})", + memo = memo.tracing_debug() + ); + if !self.validate_provisional(db, zalsa, memo) { + return false; + } + } + let verified_at = memo.verified_at.load(); + let revision_now = zalsa.current_revision(); if verified_at == revision_now { // Already verified. @@ -159,21 +232,42 @@ where false } - /// True if the memo's value and `changed_at` time is up-to-date in the current - /// revision. When this returns true, it also updates the memo's `verified_at` - /// field if needed to make future calls cheaper. + /// Check if this memo's cycle heads have all been finalized. If so, mark it verified final and + /// return true, if not return false. + fn validate_provisional( + &self, + db: &C::DbView, + zalsa: &Zalsa, + memo: &Memo>, + ) -> bool { + for cycle_head in &memo.revisions.cycle_heads { + if zalsa + .lookup_ingredient(cycle_head.ingredient_index) + .is_provisional_cycle_head(db.as_dyn_database(), cycle_head.key_index) + { + return false; + } + } + // Relaxed is sufficient here because there are no other writes we need to ensure have + // happened before marking this memo as verified-final. + memo.verified_final.store(true, Ordering::Relaxed); + true + } + + /// VerifyResult::Unchanged if the memo's value and `changed_at` time is up-to-date in the + /// current revision. 
When this returns Unchanged with no cycle heads, it also updates the + /// memo's `verified_at` field if needed to make future calls cheaper. /// /// Takes an [`ActiveQueryGuard`] argument because this function recursively /// walks dependencies of `old_memo` and may even execute them to see if their - /// outputs have changed. As that could lead to cycles, it is important that the - /// query is on the stack. + /// outputs have changed. pub(super) fn deep_verify_memo( &self, db: &C::DbView, zalsa: &Zalsa, old_memo: &Memo>, active_query: &ActiveQueryGuard<'_>, - ) -> bool { + ) -> VerifyResult { let database_key_index = active_query.database_key_index; tracing::debug!( @@ -181,78 +275,130 @@ where old_memo = old_memo.tracing_debug() ); - if self.shallow_verify_memo(db, zalsa, database_key_index, old_memo) { - return true; + if self.shallow_verify_memo(db, zalsa, database_key_index, old_memo, false) { + return VerifyResult::Unchanged(InputAccumulatedValues::Empty, Default::default()); + } + if old_memo.may_be_provisional() { + return VerifyResult::Changed; } - match &old_memo.revisions.origin { - QueryOrigin::Assigned(_) => { - // If the value was assigneed by another query, - // and that query were up-to-date, - // then we would have updated the `verified_at` field already. - // So the fact that we are here means that it was not specified - // during this revision or is otherwise stale. - // - // Example of how this can happen: - // - // Conditionally specified queries - // where the value is specified - // in rev 1 but not in rev 2. - false - } - QueryOrigin::DerivedUntracked(_) => { - // Untracked inputs? Have to assume that it changed. - false - } - QueryOrigin::Derived(edges) => { - // Fully tracked inputs? Iterate over the inputs and check them, one by one. - // - // NB: It's important here that we are iterating the inputs in the order that - // they executed. 
It's possible that if the value of some input I0 is no longer - // valid, then some later input I1 might never have executed at all, so verifying - // it is still up to date is meaningless. - let last_verified_at = old_memo.verified_at.load(); - let mut inputs = InputAccumulatedValues::Empty; - let dyn_db = db.as_dyn_database(); - for &edge in edges.input_outputs.iter() { - match edge { - QueryEdge::Input(dependency_index) => { - match dependency_index.maybe_changed_after(dyn_db, last_verified_at) { - MaybeChangedAfter::Yes => return false, - MaybeChangedAfter::No(input_accumulated) => { - inputs |= input_accumulated; + let mut cycle_heads = FxHashSet::default(); + loop { + let inputs = match &old_memo.revisions.origin { + QueryOrigin::Assigned(_) => { + // If the value was assigneed by another query, + // and that query were up-to-date, + // then we would have updated the `verified_at` field already. + // So the fact that we are here means that it was not specified + // during this revision or is otherwise stale. + // + // Example of how this can happen: + // + // Conditionally specified queries + // where the value is specified + // in rev 1 but not in rev 2. + return VerifyResult::Changed; + } + QueryOrigin::FixpointInitial => { + return VerifyResult::unchanged(); + } + QueryOrigin::DerivedUntracked(_) => { + // Untracked inputs? Have to assume that it changed. + return VerifyResult::Changed; + } + QueryOrigin::Derived(edges) => { + // Fully tracked inputs? Iterate over the inputs and check them, one by one. + // + // NB: It's important here that we are iterating the inputs in the order that + // they executed. It's possible that if the value of some input I0 is no longer + // valid, then some later input I1 might never have executed at all, so verifying + // it is still up to date is meaningless. 
+ let last_verified_at = old_memo.verified_at.load(); + let mut inputs = InputAccumulatedValues::Empty; + let dyn_db = db.as_dyn_database(); + for &edge in edges.input_outputs.iter() { + match edge { + QueryEdge::Input(dependency_index) => { + match dependency_index + .maybe_changed_after(db.as_dyn_database(), last_verified_at) + { + VerifyResult::Changed => return VerifyResult::Changed, + VerifyResult::Unchanged(input_accumulated, cycles) => { + cycle_heads.extend(cycles); + inputs |= input_accumulated; + } } } - } - QueryEdge::Output(dependency_index) => { - // Subtle: Mark outputs as validated now, even though we may - // later find an input that requires us to re-execute the function. - // Even if it re-execute, the function will wind up writing the same value, - // since all prior inputs were green. It's important to do this during - // this loop, because it's possible that one of our input queries will - // re-execute and may read one of our earlier outputs - // (e.g., in a scenario where we do something like - // `e = Entity::new(..); query(e);` and `query` reads a field of `e`). - // - // NB. Accumulators are also outputs, but the above logic doesn't - // quite apply to them. Since multiple values are pushed, the first value - // may be unchanged, but later values could be different. - // In that case, however, the data accumulated - // by this function cannot be read until this function is marked green, - // so even if we mark them as valid here, the function will re-execute - // and overwrite the contents. - dependency_index.mark_validated_output( - zalsa, - dyn_db, - database_key_index, - ); + QueryEdge::Output(dependency_index) => { + // Subtle: Mark outputs as validated now, even though we may + // later find an input that requires us to re-execute the function. + // Even if it re-execute, the function will wind up writing the same value, + // since all prior inputs were green. 
It's important to do this during + // this loop, because it's possible that one of our input queries will + // re-execute and may read one of our earlier outputs + // (e.g., in a scenario where we do something like + // `e = Entity::new(..); query(e);` and `query` reads a field of `e`). + // + // NB. Accumulators are also outputs, but the above logic doesn't + // quite apply to them. Since multiple values are pushed, the first value + // may be unchanged, but later values could be different. + // In that case, however, the data accumulated + // by this function cannot be read until this function is marked green, + // so even if we mark them as valid here, the function will re-execute + // and overwrite the contents. + dependency_index.mark_validated_output( + zalsa, + dyn_db, + database_key_index, + ); + } } } + inputs } + }; + + // Possible scenarios here: + // + // 1. Cycle heads is empty. We traversed our full dependency graph and neither hit any + // cycles, nor found any changed dependencies. We can mark our memo verified and + // return Unchanged with empty cycle heads. + // + // 2. Cycle heads is non-empty, and does not contain our own key index. We are part of + // a cycle, and since we don't know if some other cycle participant that hasn't been + // traversed yet (that is, some other dependency of the cycle head, which is only a + // dependency of ours via the cycle) might still have changed, we can't yet mark our + // memo verified. We can return a provisional Unchanged, with cycle heads. + // + // 3. Cycle heads is non-empty, and contains only our own key index. We are the head of + // a cycle, and we've now traversed the entire cycle and found no changes, but no + // other cycle participants were verified (they would have all hit case 2 above). We + // can now safely mark our own memo as verified. Then we have to traverse the entire + // cycle again. 
This time, since our own memo is verified, there will be no cycle + // encountered, and the rest of the cycle will be able to verify itself. + // + // 4. Cycle heads is non-empty, and contains our own key index as well as other key + // indices. We are the head of a cycle nested within another cycle. We can't mark + // our own memo verified (for the same reason as in case 2: the full outer cycle + // hasn't been validated unchanged yet). We return Unchanged, with ourself removed + // from cycle heads. We will handle our own memo (and the rest of our cycle) on a + // future iteration; first the outer cycle head needs to verify itself. + let in_heads = cycle_heads.remove(&database_key_index); + + if cycle_heads.is_empty() { old_memo.mark_as_verified(db, zalsa.current_revision(), database_key_index, inputs); - true + + if in_heads { + // Iterate our dependency graph again, starting from the top. We clear the + // cycle heads here because we are starting a fresh traversal. (It might be + // logically clearer to create a new HashSet each time, but clearing the + // existing one is more efficient.) 
+ cycle_heads.clear(); + continue; + } } + return VerifyResult::Unchanged(InputAccumulatedValues::Empty, cycle_heads); } } } diff --git a/src/function/memo.rs b/src/function/memo.rs index 844da0d5f..390e4d587 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -2,6 +2,7 @@ use std::any::Any; use std::fmt::Debug; use std::fmt::Formatter; use std::ptr::NonNull; +use std::sync::atomic::{AtomicBool, Ordering}; use crate::accumulator::accumulated_map::InputAccumulatedValues; use crate::revision::AtomicRevision; @@ -9,8 +10,11 @@ use crate::table::memo::MemoTable; use crate::zalsa::MemoIngredientIndex; use crate::zalsa_local::QueryOrigin; use crate::{ - key::DatabaseKeyIndex, zalsa::Zalsa, zalsa_local::QueryRevisions, Event, EventKind, Id, - Revision, + cycle::{CycleHeads, CycleRecoveryStrategy, EMPTY_CYCLE_HEADS}, + key::DatabaseKeyIndex, + zalsa::Zalsa, + zalsa_local::QueryRevisions, + Event, EventKind, Id, Revision, }; use super::{Configuration, IngredientImpl}; @@ -82,7 +86,7 @@ impl IngredientImpl { } /// Evicts the existing memo for the given key, replacing it - /// with an equivalent memo that has no value. If the memo is untracked, BaseInput, + /// with an equivalent memo that has no value. If the memo is untracked, FixpointInitial, /// or has values assigned as output of another query, this has no effect. 
pub(super) fn evict_value_from_memo_for( table: &mut MemoTable, @@ -90,7 +94,9 @@ impl IngredientImpl { ) { let map = |memo: &mut Memo>| { match &memo.revisions.origin { - QueryOrigin::Assigned(_) | QueryOrigin::DerivedUntracked(_) => { + QueryOrigin::Assigned(_) + | QueryOrigin::DerivedUntracked(_) + | QueryOrigin::FixpointInitial => { // Careful: Cannot evict memos whose values were // assigned as output of another query // or those with untracked inputs @@ -105,6 +111,17 @@ impl IngredientImpl { table.map_memo(memo_ingredient_index, map) } + + pub(super) fn initial_value<'db>( + &'db self, + db: &'db C::DbView, + key: Id, + ) -> Option> { + match C::CYCLE_STRATEGY { + CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db, C::id_to_input(db, key))), + CycleRecoveryStrategy::Panic => None, + } + } } #[derive(Debug)] @@ -116,6 +133,9 @@ pub(super) struct Memo { /// as the current revision. pub(super) verified_at: AtomicRevision, + /// Is this memo verified to not be a provisional cycle result? + pub(super) verified_final: AtomicBool, + /// Revision information pub(super) revisions: QueryRevisions, } @@ -123,17 +143,82 @@ pub(super) struct Memo { // Memo's are stored a lot, make sure their size is doesn't randomly increase. // #[cfg(test)] const _: [(); std::mem::size_of::>()] = - [(); std::mem::size_of::<[usize; 12]>()]; + [(); std::mem::size_of::<[usize; 14]>()]; impl Memo { pub(super) fn new(value: Option, revision_now: Revision, revisions: QueryRevisions) -> Self { Memo { value, verified_at: AtomicRevision::from(revision_now), + verified_final: AtomicBool::new(revisions.cycle_heads.is_empty()), revisions, } } + /// True if this may be a provisional cycle-iteration result. 
+ #[inline] + pub(super) fn may_be_provisional(&self) -> bool { + // Relaxed is OK here, because `verified_final` is only ever mutated in one direction (from + // `false` to `true`), and changing it to `true` on memos with cycle heads where it was + // ever `false` is purely an optimization; if we read an out-of-date `false`, it just means + // we might go validate it again unnecessarily. + !self.verified_final.load(Ordering::Relaxed) + } + + /// Invoked when `refresh_memo` is about to return a memo to the caller; if that memo is + /// provisional, and its cycle head is claimed by another thread, we need to wait for that + /// other thread to complete the fixpoint iteration, and then retry fetching our own memo. + /// + /// Return `true` if the caller should retry, `false` if the caller should go ahead and return + /// this memo to the caller. + pub(super) fn provisional_retry( + &self, + db: &dyn crate::Database, + database_key_index: DatabaseKeyIndex, + ) -> bool { + let mut retry = false; + for head in self.cycle_heads() { + if head == database_key_index { + continue; + } + let ingredient = db.zalsa().lookup_ingredient(head.ingredient_index); + if !ingredient.is_provisional_cycle_head(db, head.key_index) { + // This cycle is already finalized, so we don't need to wait on it; + // keep looping through cycle heads. + retry = true; + continue; + } + if ingredient.wait_for(db, head.key_index) { + // There's a new memo available for the cycle head; fetch our own + // updated memo and see if it's still provisional or if the cycle + // has resolved. + retry = true; + continue; + } else { + // We hit a cycle blocking on the cycle head; this means it's in + // our own active query stack and we are responsible to resolve the + // cycle, so go ahead and return the provisional memo. + return false; + } + } + // If `retry` is `true`, all our cycle heads (barring ourself) are complete; re-fetch + // and we should get a non-provisional memo. 
If we get here and `retry` is still + // `false`, we have no cycle heads other than ourself, so we are a provisional value of + // the cycle head (either initial value, or from a later iteration) and should be + // returned to caller to allow fixpoint iteration to proceed. (All cases in the loop + // above other than "cycle head is self" are either terminal or set `retry`.) + retry + } + + /// Cycle heads that should be propagated to dependent queries. + pub(super) fn cycle_heads(&self) -> &CycleHeads { + if self.may_be_provisional() { + &self.revisions.cycle_heads + } else { + &EMPTY_CYCLE_HEADS + } + } + /// Mark memo as having been verified in the `revision_now`, which should /// be the current revision. pub(super) fn mark_as_verified( @@ -181,6 +266,7 @@ impl Memo { }, ) .field("verified_at", &self.memo.verified_at) + .field("verified_final", &self.memo.verified_final) .field("revisions", &self.memo.revisions) .finish() } diff --git a/src/function/specify.rs b/src/function/specify.rs index 4338e9157..ac0f58fcb 100644 --- a/src/function/specify.rs +++ b/src/function/specify.rs @@ -1,3 +1,5 @@ +use std::sync::atomic::AtomicBool; + use crate::{ accumulator::accumulated_map::InputAccumulatedValues, revision::AtomicRevision, @@ -71,17 +73,26 @@ where tracked_struct_ids: Default::default(), accumulated: Default::default(), accumulated_inputs: Default::default(), + cycle_heads: Default::default(), }; let memo_ingredient_index = self.memo_ingredient_index(zalsa, key); if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) { self.backdate_if_appropriate(old_memo, &mut revisions, &value); - self.diff_outputs(zalsa, db, database_key_index, old_memo, &mut revisions); + self.diff_outputs( + zalsa, + db, + database_key_index, + old_memo, + &mut revisions, + false, + ); } let memo = Memo { value: Some(value), verified_at: AtomicRevision::from(revision), + verified_final: AtomicBool::new(true), revisions, }; diff --git a/src/ingredient.rs 
b/src/ingredient.rs index dba70f644..0dcbb1d52 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -6,6 +6,7 @@ use std::{ use crate::{ accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues}, cycle::CycleRecoveryStrategy, + function::VerifyResult, plumbing::IngredientIndices, table::Table, zalsa::{transmute_data_mut_ptr, transmute_data_ptr, IngredientIndex, Zalsa}, @@ -60,7 +61,20 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { db: &'db dyn Database, input: Id, revision: Revision, - ) -> MaybeChangedAfter; + ) -> VerifyResult; + + /// Is the value for `input` in this ingredient a cycle head that is still provisional? + /// + /// In the case of nested cycles, we are not asking here whether the value is provisional due + /// to the outer cycle being unresolved, only whether its own cycle remains provisional. + fn is_provisional_cycle_head<'db>(&'db self, db: &'db dyn Database, input: Id) -> bool; + + /// Invoked when the current thread needs to wait for a result for the given `key_index`. + /// + /// A return value of `true` indicates that a result is now available. A return value of + /// `false` means that a cycle was encountered; the waited-on query is either already claimed + /// by the current thread, or by a thread waiting on the current thread. + fn wait_for(&self, db: &dyn Database, key_index: Id) -> bool; /// What were the inputs (if any) that were used to create the value at `key_index`. fn origin(&self, db: &dyn Database, key_index: Id) -> Option; @@ -97,6 +111,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { db: &dyn Database, executor: DatabaseKeyIndex, stale_output_key: Id, + provisional: bool, ); /// Returns the [`IngredientIndex`] of this ingredient. @@ -179,23 +194,3 @@ pub(crate) fn fmt_index( write!(fmt, "{debug_name}()") } } - -#[derive(Copy, Clone, Debug)] -pub enum MaybeChangedAfter { - /// The query result hasn't changed. 
- /// - /// The inner value tracks whether the memo or any of its dependencies have an accumulated value. - No(InputAccumulatedValues), - - /// The query's result has changed since the last revision or the query isn't cached yet. - Yes, -} - -impl From for MaybeChangedAfter { - fn from(value: bool) -> Self { - match value { - true => MaybeChangedAfter::Yes, - false => MaybeChangedAfter::No(InputAccumulatedValues::Empty), - } - } -} diff --git a/src/input.rs b/src/input.rs index cf99e9702..75d28b5d2 100644 --- a/src/input.rs +++ b/src/input.rs @@ -12,9 +12,10 @@ use input_field::FieldIngredientImpl; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, - cycle::CycleRecoveryStrategy, + cycle::{CycleRecoveryStrategy, EMPTY_CYCLE_HEADS}, + function::VerifyResult, id::{AsId, FromIdWithDb}, - ingredient::{fmt_index, Ingredient, MaybeChangedAfter}, + ingredient::{fmt_index, Ingredient}, input::singleton::{Singleton, SingletonChoice}, key::{DatabaseKeyIndex, InputDependencyIndex}, plumbing::{Jar, Stamp}, @@ -182,6 +183,7 @@ impl IngredientImpl { stamp.durability, stamp.changed_at, InputAccumulatedValues::Empty, + &EMPTY_CYCLE_HEADS, ); &value.fields } @@ -220,10 +222,18 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, _input: Id, _revision: Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { // Input ingredients are just a counter, they store no data, they are immortal. // Their *fields* are stored in function ingredients elsewhere. 
- MaybeChangedAfter::No(InputAccumulatedValues::Empty) + VerifyResult::unchanged() + } + + fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + + fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { + true } fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { @@ -251,6 +261,7 @@ impl Ingredient for IngredientImpl { _db: &dyn Database, executor: DatabaseKeyIndex, stale_output_key: Id, + _provisional: bool, ) { unreachable!( "remove_stale_output({:?}, {:?}): input cannot be the output of a tracked function", diff --git a/src/input/input_field.rs b/src/input/input_field.rs index 362d3675c..6f5f9c226 100644 --- a/src/input/input_field.rs +++ b/src/input/input_field.rs @@ -1,5 +1,6 @@ use crate::cycle::CycleRecoveryStrategy; -use crate::ingredient::{fmt_index, Ingredient, MaybeChangedAfter}; +use crate::function::VerifyResult; +use crate::ingredient::{fmt_index, Ingredient}; use crate::input::Configuration; use crate::zalsa::IngredientIndex; use crate::zalsa_local::QueryOrigin; @@ -54,11 +55,18 @@ where db: &dyn Database, input: Id, revision: Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { let zalsa = db.zalsa(); let value = >::data(zalsa, input); + VerifyResult::changed_if(value.stamps[self.field_index].changed_at > revision) + } + + fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } - MaybeChangedAfter::from(value.stamps[self.field_index].changed_at > revision) + fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { + true } fn origin(&self, _db: &dyn Database, _key_index: Id) -> Option { @@ -78,6 +86,7 @@ where _db: &dyn Database, _executor: DatabaseKeyIndex, _stale_output_key: Id, + _provisional: bool, ) { } diff --git a/src/interned.rs b/src/interned.rs index 86df09c5b..2e1ed07f3 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -1,8 +1,10 @@ use dashmap::SharedValue; use 
crate::accumulator::accumulated_map::InputAccumulatedValues; +use crate::cycle::EMPTY_CYCLE_HEADS; use crate::durability::Durability; -use crate::ingredient::{fmt_index, MaybeChangedAfter}; +use crate::function::VerifyResult; +use crate::ingredient::fmt_index; use crate::key::InputDependencyIndex; use crate::plumbing::{IngredientIndices, Jar}; use crate::table::memo::MemoTable; @@ -183,6 +185,7 @@ where Durability::MAX, self.reset_at, InputAccumulatedValues::Empty, + &EMPTY_CYCLE_HEADS, ); // Optimization to only get read lock on the map if the data has already been interned. @@ -287,8 +290,16 @@ where _db: &dyn Database, _input: Id, revision: Revision, - ) -> MaybeChangedAfter { - MaybeChangedAfter::from(revision < self.reset_at) + ) -> VerifyResult { + VerifyResult::changed_if(revision < self.reset_at) + } + + fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + + fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { + true } fn cycle_recovery_strategy(&self) -> crate::cycle::CycleRecoveryStrategy { @@ -316,6 +327,7 @@ where _db: &dyn Database, executor: DatabaseKeyIndex, stale_output_key: crate::Id, + _provisional: bool, ) { unreachable!( "remove_stale_output({:?}, {:?}): interned ids are not outputs", diff --git a/src/key.rs b/src/key.rs index f3e90bb10..61b748f06 100644 --- a/src/key.rs +++ b/src/key.rs @@ -1,9 +1,7 @@ use core::fmt; use crate::{ - accumulator::accumulated_map::InputAccumulatedValues, - cycle::CycleRecoveryStrategy, - ingredient::MaybeChangedAfter, + function::VerifyResult, zalsa::{IngredientIndex, Zalsa}, Database, Id, }; @@ -41,10 +39,11 @@ impl OutputDependencyIndex { zalsa: &Zalsa, db: &dyn Database, executor: DatabaseKeyIndex, + provisional: bool, ) { zalsa .lookup_ingredient(self.ingredient_index) - .remove_stale_output(db, executor, self.key_index) + .remove_stale_output(db, executor, self.key_index, provisional) } pub(crate) fn mark_validated_output( @@ -98,7 +97,7 @@ impl 
InputDependencyIndex { &self, db: &dyn Database, last_verified_at: crate::Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { match self.key_index { // SAFETY: The `db` belongs to the ingredient Some(key_index) => unsafe { @@ -107,7 +106,7 @@ impl InputDependencyIndex { .maybe_changed_after(db, key_index, last_verified_at) }, // Data in tables themselves remain valid until the table as a whole is reset. - None => MaybeChangedAfter::No(InputAccumulatedValues::Empty), + None => VerifyResult::unchanged(), } } @@ -150,10 +149,6 @@ impl DatabaseKeyIndex { pub fn key_index(self) -> Id { self.key_index } - - pub(crate) fn cycle_recovery_strategy(self, db: &dyn Database) -> CycleRecoveryStrategy { - self.ingredient_index.cycle_recovery_strategy(db) - } } impl std::fmt::Debug for DatabaseKeyIndex { diff --git a/src/lib.rs b/src/lib.rs index 724506030..88d6fbd78 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,7 +34,7 @@ mod zalsa_local; pub use self::accumulator::Accumulator; pub use self::cancelled::Cancelled; -pub use self::cycle::Cycle; +pub use self::cycle::CycleRecoveryAction; pub use self::database::AsDynDatabase; pub use self::database::Database; pub use self::database_impl::DatabaseImpl; @@ -74,11 +74,11 @@ pub mod plumbing { pub use crate::array::Array; pub use crate::attach::attach; pub use crate::attach::with_attached_database; - pub use crate::cycle::Cycle; + pub use crate::cycle::CycleRecoveryAction; pub use crate::cycle::CycleRecoveryStrategy; pub use crate::database::current_revision; pub use crate::database::Database; - pub use crate::function::should_backdate_value; + pub use crate::function::values_equal; pub use crate::id::AsId; pub use crate::id::FromId; pub use crate::id::FromIdWithDb; @@ -122,6 +122,7 @@ pub mod plumbing { pub use salsa_macro_rules::setup_method_body; pub use salsa_macro_rules::setup_tracked_fn; pub use salsa_macro_rules::setup_tracked_struct; + pub use salsa_macro_rules::unexpected_cycle_initial; pub use 
salsa_macro_rules::unexpected_cycle_recovery; pub mod accumulator { diff --git a/src/runtime.rs b/src/runtime.rs index 116c6a752..e04ac8da7 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,19 +1,14 @@ use std::{ mem, - panic::panic_any, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, + sync::atomic::{AtomicBool, Ordering}, thread::ThreadId, }; use parking_lot::Mutex; use crate::{ - active_query::ActiveQuery, cycle::CycleRecoveryStrategy, durability::Durability, - key::DatabaseKeyIndex, table::Table, zalsa_local::ZalsaLocal, Cancelled, Cycle, Database, - Event, EventKind, Revision, + durability::Durability, key::DatabaseKeyIndex, table::Table, zalsa_local::ZalsaLocal, + Cancelled, Database, Event, EventKind, Revision, }; use self::dependency_graph::DependencyGraph; @@ -49,7 +44,12 @@ pub struct Runtime { pub(crate) enum WaitResult { Completed, Panicked, - Cycle(Cycle), +} + +#[derive(Clone, Debug)] +pub(crate) enum BlockResult { + Completed, + Cycle, } #[derive(Copy, Clone, Debug)] @@ -156,8 +156,8 @@ impl Runtime { r_new } - /// Block until `other_id` completes executing `database_key`; - /// panic or unwind in the case of a cycle. + /// Block until `other_id` completes executing `database_key`, or return `BlockResult::Cycle` + /// immediately in case of a cycle. /// /// `query_mutex_guard` is the guard for the current query's state; /// it will be dropped after we have successfully registered the @@ -167,34 +167,19 @@ impl Runtime { /// /// If the thread `other_id` panics, then our thread is considered /// cancelled, so this function will panic with a `Cancelled` value. - /// - /// # Cycle handling - /// - /// If the thread `other_id` already depends on the current thread, - /// and hence there is a cycle in the query graph, then this function - /// will unwind instead of returning normally. 
The method of unwinding - /// depends on the [`Self::mutual_cycle_recovery_strategy`] - /// of the cycle participants: - /// - /// * [`CycleRecoveryStrategy::Panic`]: panic with the [`Cycle`] as the value. - /// * [`CycleRecoveryStrategy::Fallback`]: initiate unwinding with [`CycleParticipant::unwind`]. - pub(crate) fn block_on_or_unwind( + pub(crate) fn block_on( &self, db: &dyn Database, local_state: &ZalsaLocal, database_key: DatabaseKeyIndex, other_id: ThreadId, query_mutex_guard: QueryMutexGuard, - ) { + ) -> BlockResult { let mut dg = self.dependency_graph.lock(); let thread_id = std::thread::current().id(); if dg.depends_on(other_id, thread_id) { - self.unblock_cycle_and_maybe_throw(db, local_state, &mut dg, database_key, other_id); - - // If the above fn returns, then (via cycle recovery) it has unblocked the - // cycle, so we can continue. - assert!(!dg.depends_on(other_id, thread_id)); + return BlockResult::Cycle; } db.salsa_event(&|| { @@ -218,126 +203,12 @@ impl Runtime { }); match result { - WaitResult::Completed => (), + WaitResult::Completed => BlockResult::Completed, // If the other thread panicked, then we consider this thread // cancelled. The assumption is that the panic will be detected // by the other thread and responded to appropriately. WaitResult::Panicked => Cancelled::PropagatedPanic.throw(), - - WaitResult::Cycle(c) => c.throw(), - } - } - - /// Handles a cycle in the dependency graph that was detected when the - /// current thread tried to block on `database_key_index` which is being - /// executed by `to_id`. If this function returns, then `to_id` no longer - /// depends on the current thread, and so we should continue executing - /// as normal. Otherwise, the function will throw a `Cycle` which is expected - /// to be caught by some frame on our stack. This occurs either if there is - /// a frame on our stack with cycle recovery (possibly the top one!) or if there - /// is no cycle recovery at all. 
- fn unblock_cycle_and_maybe_throw( - &self, - db: &dyn Database, - local_state: &ZalsaLocal, - dg: &mut DependencyGraph, - database_key_index: DatabaseKeyIndex, - to_id: ThreadId, - ) { - tracing::debug!( - "unblock_cycle_and_maybe_throw(database_key={:?})", - database_key_index - ); - - let (me_recovered, others_recovered, cycle) = local_state.with_query_stack(|from_stack| { - let from_id = std::thread::current().id(); - - // Make a "dummy stack frame". As we iterate through the cycle, we will collect the - // inputs from each participant. Then, if we are participating in cycle recovery, we - // will propagate those results to all participants. - let mut cycle_query = ActiveQuery::new(database_key_index); - - // Identify the cycle participants: - let cycle = { - let mut v = vec![]; - dg.for_each_cycle_participant( - from_id, - from_stack, - database_key_index, - to_id, - |aqs| { - aqs.iter_mut().for_each(|aq| { - cycle_query.add_from(aq); - v.push(aq.database_key_index); - }); - }, - ); - - // We want to give the participants in a deterministic order - // (at least for this execution, not necessarily across executions), - // no matter where it started on the stack. Find the minimum - // key and rotate it to the front. - - if let Some((_, index)) = v - .iter() - .enumerate() - .map(|(idx, key)| (key.ingredient_index.debug_name(db), idx)) - .min() - { - v.rotate_left(index); - } - - Cycle::new(Arc::new(v.into_boxed_slice())) - }; - tracing::debug!("cycle {cycle:?}, cycle_query {cycle_query:#?}"); - - // We can remove the cycle participants from the list of dependencies; - // they are a strongly connected component (SCC) and we only care about - // dependencies to things outside the SCC that control whether it will - // form again. - cycle_query.remove_cycle_participants(&cycle); - - // Mark each cycle participant that has recovery set, along with - // any frames that come after them on the same thread. 
Those frames - // are going to be unwound so that fallback can occur. - dg.for_each_cycle_participant(from_id, from_stack, database_key_index, to_id, |aqs| { - aqs.iter_mut() - .skip_while(|aq| { - match db - .zalsa() - .lookup_ingredient(aq.database_key_index.ingredient_index) - .cycle_recovery_strategy() - { - CycleRecoveryStrategy::Panic => true, - CycleRecoveryStrategy::Fallback => false, - } - }) - .for_each(|aq| { - tracing::debug!("marking {:?} for fallback", aq.database_key_index); - aq.take_inputs_from(&cycle_query); - assert!(aq.cycle.is_none()); - aq.cycle = Some(cycle.clone()); - }); - }); - - // Unblock every thread that has cycle recovery with a `WaitResult::Cycle`. - // They will throw the cycle, which will be caught by the frame that has - // cycle recovery so that it can execute that recovery. - let (me_recovered, others_recovered) = - dg.maybe_unblock_runtimes_in_cycle(from_id, from_stack, database_key_index, to_id); - (me_recovered, others_recovered, cycle) - }); - - if me_recovered { - // If the current thread has recovery, we want to throw - // so that it can begin. - cycle.throw() - } else if others_recovered { - // If other threads have recovery but we didn't: return and we will block on them. - } else { - // if nobody has recover, then we panic - panic_any(cycle); } } diff --git a/src/runtime/dependency_graph.rs b/src/runtime/dependency_graph.rs index a8da9d3ed..ba04f7ec9 100644 --- a/src/runtime/dependency_graph.rs +++ b/src/runtime/dependency_graph.rs @@ -31,7 +31,6 @@ pub(super) struct DependencyGraph { #[derive(Debug)] struct Edge { blocked_on_id: ThreadId, - blocked_on_key: DatabaseKeyIndex, stack: QueryStack, /// Signalled whenever a query with dependents completes. @@ -55,115 +54,6 @@ impl DependencyGraph { p == to_id } - /// Invokes `closure` with a `&mut ActiveQuery` for each query that participates in the cycle. - /// The cycle runs as follows: - /// - /// 1. 
The runtime `from_id`, which has the stack `from_stack`, would like to invoke `database_key`... - /// 2. ...but `database_key` is already being executed by `to_id`... - /// 3. ...and `to_id` is transitively dependent on something which is present on `from_stack`. - pub(super) fn for_each_cycle_participant( - &mut self, - from_id: ThreadId, - from_stack: &mut QueryStack, - database_key: DatabaseKeyIndex, - to_id: ThreadId, - mut closure: impl FnMut(&mut [ActiveQuery]), - ) { - debug_assert!(self.depends_on(to_id, from_id)); - - // To understand this algorithm, consider this [drawing](https://is.gd/TGLI9v): - // - // database_key = QB2 - // from_id = A - // to_id = B - // from_stack = [QA1, QA2, QA3] - // - // self.edges[B] = { C, QC2, [QB1..QB3] } - // self.edges[C] = { A, QA2, [QC1..QC3] } - // - // The cyclic - // edge we have - // failed to add. - // : - // A : B C - // : - // QA1 v QB1 QC1 - // ┌► QA2 ┌──► QB2 ┌─► QC2 - // │ QA3 ───┘ QB3 ──┘ QC3 ───┐ - // │ │ - // └───────────────────────────────┘ - // - // Final output: [QB2, QB3, QC2, QC3, QA2, QA3] - - let mut id = to_id; - let mut key = database_key; - while id != from_id { - // Looking at the diagram above, the idea is to - // take the edge from `to_id` starting at `key` - // (inclusive) and down to the end. We can then - // load up the next thread (i.e., we start at B/QB2, - // and then load up the dependency on C/QC2). - let edge = self.edges.get_mut(&id).unwrap(); - closure(strip_prefix_query_stack_mut(&mut edge.stack, key)); - id = edge.blocked_on_id; - key = edge.blocked_on_key; - } - - // Finally, we copy in the results from `from_stack`. - closure(strip_prefix_query_stack_mut(from_stack, key)); - } - - /// Unblock each blocked runtime (excluding the current one) if some - /// query executing in that runtime is participating in cycle fallback. 
- /// - /// Returns a boolean (Current, Others) where: - /// * Current is true if the current runtime has cycle participants - /// with fallback; - /// * Others is true if other runtimes were unblocked. - pub(super) fn maybe_unblock_runtimes_in_cycle( - &mut self, - from_id: ThreadId, - from_stack: &QueryStack, - database_key: DatabaseKeyIndex, - to_id: ThreadId, - ) -> (bool, bool) { - // See diagram in `for_each_cycle_participant`. - let mut id = to_id; - let mut key = database_key; - let mut others_unblocked = false; - while id != from_id { - let edge = self.edges.get(&id).unwrap(); - let next_id = edge.blocked_on_id; - let next_key = edge.blocked_on_key; - - if let Some(cycle) = strip_prefix_query_stack(&edge.stack, key) - .iter() - .rev() - .find_map(|aq| aq.cycle.clone()) - { - // Remove `id` from the list of runtimes blocked on `next_key`: - self.query_dependents - .get_mut(&next_key) - .unwrap() - .retain(|r| *r != id); - - // Unblock runtime so that it can resume execution once lock is released: - self.unblock_runtime(id, WaitResult::Cycle(cycle)); - - others_unblocked = true; - } - - id = next_id; - key = next_key; - } - - let this_unblocked = strip_prefix_query_stack(from_stack, key) - .iter() - .any(|aq| aq.cycle.is_some()); - - (this_unblocked, others_unblocked) - } - /// Modifies the graph so that `from_id` is blocked /// on `database_key`, which is being computed by /// `to_id`. @@ -219,7 +109,6 @@ impl DependencyGraph { from_id, Edge { blocked_on_id: to_id, - blocked_on_key: database_key, stack: from_stack, condvar: condvar.clone(), }, @@ -260,22 +149,3 @@ impl DependencyGraph { edge.condvar.notify_one(); } } - -fn strip_prefix_query_stack(stack_mut: &[ActiveQuery], key: DatabaseKeyIndex) -> &[ActiveQuery] { - let prefix = stack_mut - .iter() - .take_while(|p| p.database_key_index != key) - .count(); - &stack_mut[prefix..] 
-} - -fn strip_prefix_query_stack_mut( - stack_mut: &mut [ActiveQuery], - key: DatabaseKeyIndex, -) -> &mut [ActiveQuery] { - let prefix = stack_mut - .iter() - .take_while(|p| p.database_key_index != key) - .count(); - &mut stack_mut[prefix..] -} diff --git a/src/table/sync.rs b/src/table/sync.rs index 521982e62..97f175467 100644 --- a/src/table/sync.rs +++ b/src/table/sync.rs @@ -4,7 +4,7 @@ use parking_lot::Mutex; use crate::{ key::DatabaseKeyIndex, - runtime::WaitResult, + runtime::{BlockResult, WaitResult}, zalsa::{MemoIngredientIndex, Zalsa}, Database, }; @@ -26,6 +26,12 @@ struct SyncState { anyone_waiting: bool, } +pub(crate) enum ClaimResult<'a> { + Retry, + Cycle, + Claimed(ClaimGuard<'a>), +} + impl SyncTable { #[inline] pub(crate) fn claim<'me>( @@ -34,7 +40,7 @@ impl SyncTable { zalsa: &'me Zalsa, database_key_index: DatabaseKeyIndex, memo_ingredient_index: MemoIngredientIndex, - ) -> Option> { + ) -> ClaimResult<'me> { let mut syncs = self.syncs.lock(); let thread_id = std::thread::current().id(); @@ -46,26 +52,35 @@ impl SyncTable { id: thread_id, anyone_waiting: false, }); - Some(ClaimGuard { + ClaimResult::Claimed(ClaimGuard { database_key_index, memo_ingredient_index, zalsa, sync_table: self, + _padding: false, }) } Some(SyncState { id: other_id, anyone_waiting, }) => { + // NB: `Ordering::Relaxed` is sufficient here, + // as there are no loads that are "gated" on this + // value. Everything that is written is also protected + // by a lock that must be acquired. The role of this + // boolean is to decide *whether* to acquire the lock, + // not to gate future atomic reads. 
*anyone_waiting = true; - zalsa.runtime().block_on_or_unwind( + match zalsa.runtime().block_on( db.as_dyn_database(), db.zalsa_local(), database_key_index, *other_id, syncs, - ); - None + ) { + BlockResult::Completed => ClaimResult::Retry, + BlockResult::Cycle => ClaimResult::Cycle, + } } } } @@ -79,6 +94,9 @@ pub(crate) struct ClaimGuard<'me> { memo_ingredient_index: MemoIngredientIndex, zalsa: &'me Zalsa, sync_table: &'me SyncTable, + // Reduce the size of ClaimResult by making more niches available in ClaimGuard; this fits into + // the padding of ClaimGuard so doesn't increase its size. + _padding: bool, } impl ClaimGuard<'_> { diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index c88a47b04..a3b01dc1d 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -5,8 +5,9 @@ use tracked_field::FieldIngredientImpl; use crate::{ accumulator::accumulated_map::InputAccumulatedValues, - cycle::CycleRecoveryStrategy, - ingredient::{fmt_index, Ingredient, Jar, MaybeChangedAfter}, + cycle::{CycleRecoveryStrategy, EMPTY_CYCLE_HEADS}, + function::VerifyResult, + ingredient::{fmt_index, Ingredient, Jar}, key::{DatabaseKeyIndex, InputDependencyIndex}, plumbing::ZalsaLocal, revision::OptionalAtomicRevision, @@ -586,7 +587,7 @@ where /// Using this method on an entity id that MAY be used in the current revision will lead to /// unspecified results (but not UB). See [`InternedIngredient::delete_index`] for more /// discussion and important considerations. 
- pub(crate) fn delete_entity(&self, db: &dyn crate::Database, id: Id) { + pub(crate) fn delete_entity(&self, db: &dyn crate::Database, id: Id, provisional: bool) { db.salsa_event(&|| { Event::new(crate::EventKind::DidDiscard { key: self.database_key_index(id), @@ -604,7 +605,7 @@ where None => { panic!("cannot delete write-locked id `{id:?}`; value leaked across threads"); } - Some(r) if r == current_revision => panic!( + Some(r) if !provisional && r == current_revision => panic!( "cannot delete read-locked id `{id:?}`; value leaked across threads or user functions not deterministic" ), Some(r) => { @@ -632,7 +633,7 @@ where db.salsa_event(&|| Event::new(EventKind::DidDiscard { key: executor })); for stale_output in memo.origin().outputs() { - stale_output.remove_stale_output(zalsa, db, executor); + stale_output.remove_stale_output(zalsa, db, executor, provisional); } } @@ -681,6 +682,7 @@ where data.durability, field_changed_at, InputAccumulatedValues::Empty, + &EMPTY_CYCLE_HEADS, ); unsafe { self.to_self_ref(&data.fields) } @@ -707,6 +709,7 @@ where data.durability, data.created_at, InputAccumulatedValues::Empty, + &EMPTY_CYCLE_HEADS, ); unsafe { self.to_self_ref(&data.fields) } @@ -740,11 +743,19 @@ where db: &dyn Database, input: Id, revision: Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { let zalsa = db.zalsa(); let data = Self::data(zalsa.table(), input); - MaybeChangedAfter::from(data.created_at > revision) + VerifyResult::changed_if(data.created_at > revision) + } + + fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + + fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { + true } fn cycle_recovery_strategy(&self) -> CycleRecoveryStrategy { @@ -771,12 +782,13 @@ where db: &dyn Database, _executor: DatabaseKeyIndex, stale_output_key: crate::Id, + provisional: bool, ) { // This method is called when, in prior revisions, // `executor` creates a tracked struct `salsa_output_key`, // 
but it did not in the current revision. // In that case, we can delete `stale_output_key` and any data associated with it. - self.delete_entity(db.as_dyn_database(), stale_output_key); + self.delete_entity(db.as_dyn_database(), stale_output_key, provisional); } fn fmt_index(&self, index: Option, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/src/tracked_struct/tracked_field.rs b/src/tracked_struct/tracked_field.rs index b69ffebd1..9b264949c 100644 --- a/src/tracked_struct/tracked_field.rs +++ b/src/tracked_struct/tracked_field.rs @@ -1,10 +1,6 @@ use std::marker::PhantomData; -use crate::{ - ingredient::{Ingredient, MaybeChangedAfter}, - zalsa::IngredientIndex, - Database, Id, -}; +use crate::{function::VerifyResult, ingredient::Ingredient, zalsa::IngredientIndex, Database, Id}; use super::{Configuration, Value}; @@ -60,11 +56,19 @@ where db: &'db dyn Database, input: Id, revision: crate::Revision, - ) -> MaybeChangedAfter { + ) -> VerifyResult { let zalsa = db.zalsa(); let data = >::data(zalsa.table(), input); let field_changed_at = data.revisions[self.field_index]; - MaybeChangedAfter::from(field_changed_at > revision) + VerifyResult::changed_if(field_changed_at > revision) + } + + fn is_provisional_cycle_head<'db>(&'db self, _db: &'db dyn Database, _input: Id) -> bool { + false + } + + fn wait_for(&self, _db: &dyn Database, _key_index: Id) -> bool { + true } fn origin( @@ -89,6 +93,7 @@ where _db: &dyn Database, _executor: crate::DatabaseKeyIndex, _stale_output_key: crate::Id, + _provisional: bool, ) { panic!("tracked field ingredients have no outputs") } diff --git a/src/zalsa.rs b/src/zalsa.rs index 2a4869615..f6a29e7cc 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -9,7 +9,6 @@ use std::num::NonZeroU32; use std::panic::RefUnwindSafe; use std::sync::atomic::Ordering; -use crate::cycle::CycleRecoveryStrategy; use crate::ingredient::{Ingredient, Jar}; use crate::nonce::{Nonce, NonceGenerator}; use crate::runtime::Runtime; @@ -90,18 +89,9 @@ impl 
IngredientIndex { self.0 as usize } - pub(crate) fn cycle_recovery_strategy(self, db: &dyn Database) -> CycleRecoveryStrategy { - db.zalsa().lookup_ingredient(self).cycle_recovery_strategy() - } - pub fn successor(self, index: usize) -> Self { IngredientIndex(self.0 + 1 + index as u32) } - - /// Return the "debug name" of this ingredient (e.g., the name of the tracked struct it represents) - pub(crate) fn debug_name(self, db: &dyn Database) -> &'static str { - db.zalsa().lookup_ingredient(self).debug_name() - } } /// A special secondary index *just* for ingredients that attach diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index ad3ece3a7..b9db8b641 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -1,10 +1,11 @@ -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; use tracing::debug; use crate::accumulator::accumulated_map::{ AccumulatedMap, AtomicInputAccumulatedValues, InputAccumulatedValues, }; use crate::active_query::ActiveQuery; +use crate::cycle::CycleHeads; use crate::durability::Durability; use crate::key::{DatabaseKeyIndex, InputDependencyIndex, OutputDependencyIndex}; use crate::runtime::StampedValue; @@ -15,7 +16,6 @@ use crate::tracked_struct::{Disambiguator, Identity, IdentityHash, IdentityMap}; use crate::zalsa::IngredientIndex; use crate::Accumulator; use crate::Cancelled; -use crate::Cycle; use crate::Id; use crate::Revision; use std::cell::RefCell; @@ -170,6 +170,7 @@ impl ZalsaLocal { durability: Durability, changed_at: Revision, accumulated: InputAccumulatedValues, + cycle_heads: &CycleHeads, ) { debug!( "report_tracked_read(input={:?}, durability={:?}, changed_at={:?})", @@ -177,32 +178,7 @@ impl ZalsaLocal { ); self.with_query_stack(|stack| { if let Some(top_query) = stack.last_mut() { - top_query.add_read(input, durability, changed_at, accumulated); - - // We are a cycle participant: - // - // C0 --> ... --> Ci --> Ci+1 -> ... 
-> Cn --> C0 - // ^ ^ - // : | - // This edge -----+ | - // | - // | - // N0 - // - // In this case, the value we have just read from `Ci+1` - // is actually the cycle fallback value and not especially - // interesting. We unwind now with `CycleParticipant` to avoid - // executing the rest of our query function. This unwinding - // will be caught and our own fallback value will be used. - // - // Note that `Ci+1` may` have *other* callers who are not - // participants in the cycle (e.g., N0 in the graph above). - // They will not have the `cycle` marker set in their - // stack frames, so they will just read the fallback value - // from `Ci+1` and continue on their merry way. - if let Some(cycle) = &top_query.cycle { - cycle.clone().throw() - } + top_query.add_read(input, durability, changed_at, accumulated, cycle_heads); } }) } @@ -330,12 +306,36 @@ pub(crate) struct QueryRevisions { pub(super) tracked_struct_ids: IdentityMap, pub(super) accumulated: Option>, + /// [`InputAccumulatedValues::Empty`] if any input read during the query's execution /// has any direct or indirect accumulated values. pub(super) accumulated_inputs: AtomicInputAccumulatedValues, + + /// This result was computed based on provisional values from + /// these cycle heads. The "cycle head" is the query responsible + /// for managing a fixpoint iteration. In a cycle like + /// `--> A --> B --> C --> A`, the cycle head is query `A`: it is + /// the query whose value is requested while it is executing, + /// which must provide the initial provisional value and decide, + /// after each iteration, whether the cycle has converged or must + /// iterate again. 
+ pub(super) cycle_heads: CycleHeads, } impl QueryRevisions { + pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex, revision: Revision) -> Self { + let cycle_heads = FxHashSet::from_iter([query]).into(); + Self { + changed_at: revision, + durability: Durability::MAX, + origin: QueryOrigin::FixpointInitial, + tracked_struct_ids: Default::default(), + accumulated: Default::default(), + accumulated_inputs: Default::default(), + cycle_heads, + } + } + pub(crate) fn stamped_value(&self, value: V) -> StampedValue { self.stamp_template().stamp(value) } @@ -381,6 +381,9 @@ pub enum QueryOrigin { /// The [`QueryEdges`] argument contains a listing of all the inputs we saw /// (but we know there were more). DerivedUntracked(QueryEdges), + + /// The value is an initial provisional value for a query that supports fixpoint iteration. + FixpointInitial, } impl QueryOrigin { @@ -388,7 +391,7 @@ impl QueryOrigin { pub(crate) fn inputs(&self) -> impl DoubleEndedIterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), - QueryOrigin::Assigned(_) => None, + QueryOrigin::Assigned(_) | QueryOrigin::FixpointInitial => None, }; opt_edges.into_iter().flat_map(|edges| edges.inputs()) } @@ -397,7 +400,7 @@ impl QueryOrigin { pub(crate) fn outputs(&self) -> impl DoubleEndedIterator + '_ { let opt_edges = match self { QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges), - QueryOrigin::Assigned(_) => None, + QueryOrigin::Assigned(_) | QueryOrigin::FixpointInitial => None, }; opt_edges.into_iter().flat_map(|edges| edges.outputs()) } @@ -509,18 +512,8 @@ impl ActiveQueryGuard<'_> { // Extract accumulated inputs. let popped_query = self.complete(); - // If this frame were a cycle participant, it would have unwound. - assert!(popped_query.cycle.is_none()); - popped_query.into_revisions() } - - /// If the active query is registered as a cycle participant, remove and - /// return that cycle. 
- pub(crate) fn take_cycle(&self) -> Option { - self.local_state - .with_query_stack(|stack| stack.last_mut()?.cycle.take()) - } } impl Drop for ActiveQueryGuard<'_> { diff --git a/tests/common/mod.rs b/tests/common/mod.rs index 19f818b65..75d22073e 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -141,3 +141,41 @@ impl HasLogger for ExecuteValidateLoggerDatabase { &self.logger } } + +/// Trait implemented by databases that lets them provide a fixed u32 value. +pub trait HasValue { + fn get_value(&self) -> u32; +} + +#[salsa::db] +pub trait ValueDatabase: HasValue + Database {} + +#[salsa::db] +impl ValueDatabase for Db {} + +#[salsa::db] +#[derive(Clone, Default)] +pub struct DatabaseWithValue { + storage: Storage, + value: u32, +} + +impl HasValue for DatabaseWithValue { + fn get_value(&self) -> u32 { + self.value + } +} + +#[salsa::db] +impl Database for DatabaseWithValue { + fn salsa_event(&self, _event: &dyn Fn() -> salsa::Event) {} +} + +impl DatabaseWithValue { + pub fn new(value: u32) -> Self { + Self { + storage: Default::default(), + value, + } + } +} diff --git a/tests/compile-fail/get-set-on-private-input-field.rs b/tests/compile-fail/get-set-on-private-input-field.rs index 5ecec5836..345590b75 100644 --- a/tests/compile-fail/get-set-on-private-input-field.rs +++ b/tests/compile-fail/get-set-on-private-input-field.rs @@ -1,5 +1,3 @@ -use salsa::prelude::*; - mod a { #[salsa::input] pub struct MyInput { diff --git a/tests/compile-fail/get-set-on-private-input-field.stderr b/tests/compile-fail/get-set-on-private-input-field.stderr index b8dcca66d..40acd8c2d 100644 --- a/tests/compile-fail/get-set-on-private-input-field.stderr +++ b/tests/compile-fail/get-set-on-private-input-field.stderr @@ -1,17 +1,17 @@ error[E0624]: method `field` is private - --> tests/compile-fail/get-set-on-private-input-field.rs:14:11 + --> tests/compile-fail/get-set-on-private-input-field.rs:12:11 | -4 | #[salsa::input] +2 | #[salsa::input] | --------------- private 
method defined here ... -14 | input.field(&db); +12 | input.field(&db); | ^^^^^ private method error[E0624]: method `set_field` is private - --> tests/compile-fail/get-set-on-private-input-field.rs:15:11 + --> tests/compile-fail/get-set-on-private-input-field.rs:13:11 | -4 | #[salsa::input] +2 | #[salsa::input] | --------------- private method defined here ... -15 | input.set_field(&mut db).to(23); +13 | input.set_field(&mut db).to(23); | ^^^^^^^^^ private method diff --git a/tests/cycle.rs b/tests/cycle.rs new file mode 100644 index 000000000..fe5875fc0 --- /dev/null +++ b/tests/cycle.rs @@ -0,0 +1,1004 @@ +//! Test cases for fixpoint iteration cycle resolution. +//! +//! These test cases use a generic query setup that allows constructing arbitrary dependency +//! graphs, and attempts to achieve good coverage of various cases. +mod common; +use common::{ExecuteValidateLoggerDatabase, LogDatabase}; +use expect_test::expect; +use salsa::{CycleRecoveryAction, Database as Db, DatabaseImpl as DbImpl, Durability, Setter}; +use test_log::test; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, salsa::Update)] +enum Value { + N(u8), + OutOfBounds, + TooManyIterations, +} + +impl Value { + fn to_value(self) -> Option { + if let Self::N(val) = self { + Some(val) + } else { + None + } + } +} + +/// A vector of inputs a query can evaluate to get an iterator of values to operate on. +/// +/// This allows creating arbitrary query graphs between the four queries below (`min_iterate`, +/// `max_iterate`, `min_panic`, `max_panic`) for testing cycle behaviors. +#[salsa::input] +struct Inputs { + inputs: Vec, +} + +impl Inputs { + fn values(self, db: &dyn Db) -> impl Iterator + '_ { + self.inputs(db).into_iter().map(|input| input.eval(db)) + } +} + +/// A single input, evaluating to a single [`Value`]. 
+#[derive(Clone, Debug)] +enum Input { + /// a simple value + Value(Value), + + /// a simple value, reported as an untracked read + UntrackedRead(Value), + + /// minimum of the given inputs, with fixpoint iteration on cycles + MinIterate(Inputs), + + /// maximum of the given inputs, with fixpoint iteration on cycles + MaxIterate(Inputs), + + /// minimum of the given inputs, panicking on cycles + MinPanic(Inputs), + + /// maximum of the given inputs, panicking on cycles + MaxPanic(Inputs), + + /// value of the given input, plus one; propagates error values + Successor(Box), + + /// successor, converts error values to zero + SuccessorOrZero(Box), +} + +impl Input { + fn eval(self, db: &dyn Db) -> Value { + match self { + Self::Value(value) => value, + Self::UntrackedRead(value) => { + db.report_untracked_read(); + value + } + Self::MinIterate(inputs) => min_iterate(db, inputs), + Self::MaxIterate(inputs) => max_iterate(db, inputs), + Self::MinPanic(inputs) => min_panic(db, inputs), + Self::MaxPanic(inputs) => max_panic(db, inputs), + Self::Successor(input) => match input.eval(db) { + Value::N(num) => Value::N(num + 1), + other => other, + }, + Self::SuccessorOrZero(input) => match input.eval(db) { + Value::N(num) => Value::N(num + 1), + _ => Value::N(0), + }, + } + } + + fn assert(self, db: &dyn Db, expected: Value) { + assert_eq!(self.eval(db), expected) + } + + fn assert_value(self, db: &dyn Db, expected: u8) { + self.assert(db, Value::N(expected)) + } + + fn assert_bounds(self, db: &dyn Db) { + self.assert(db, Value::OutOfBounds) + } + + fn assert_count(self, db: &dyn Db) { + self.assert(db, Value::TooManyIterations) + } +} + +const MIN_VALUE: u8 = 10; +const MAX_VALUE: u8 = 245; +const MAX_ITERATIONS: u32 = 3; + +/// Recover from a cycle by falling back to `Value::OutOfBounds` if the value is out of bounds, +/// `Value::TooManyIterations` if we've iterated more than `MAX_ITERATIONS` times, or else +/// iterating again. 
+fn cycle_recover( + _db: &dyn Db, + value: &Value, + count: u32, + _inputs: Inputs, +) -> CycleRecoveryAction { + if value + .to_value() + .is_some_and(|val| val <= MIN_VALUE || val >= MAX_VALUE) + { + CycleRecoveryAction::Fallback(Value::OutOfBounds) + } else if count > MAX_ITERATIONS { + CycleRecoveryAction::Fallback(Value::TooManyIterations) + } else { + CycleRecoveryAction::Iterate + } +} + +/// Fold an iterator of `Value` into a `Value`, given some binary operator to apply to two `u8`. +/// `Value::TooManyIterations` and `Value::OutOfBounds` will always propagate, with +/// `Value::TooManyIterations` taking precedence. +fn fold_values(values: impl IntoIterator, op: F) -> Value +where + F: Fn(u8, u8) -> u8, +{ + values + .into_iter() + .fold(None, |accum, elem| { + let Some(accum) = accum else { + return Some(elem); + }; + match (accum, elem) { + (Value::TooManyIterations, _) | (_, Value::TooManyIterations) => { + Some(Value::TooManyIterations) + } + (Value::OutOfBounds, _) | (_, Value::OutOfBounds) => Some(Value::OutOfBounds), + (Value::N(val1), Value::N(val2)) => Some(Value::N(op(val1, val2))), + } + }) + .expect("inputs should not be empty") +} + +/// Query minimum value of inputs, with cycle recovery. +#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=min_initial)] +fn min_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { + fold_values(inputs.values(db), u8::min) +} + +fn min_initial(_db: &dyn Db, _inputs: Inputs) -> Value { + Value::N(255) +} + +/// Query maximum value of inputs, with cycle recovery. +#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=max_initial)] +fn max_iterate<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { + fold_values(inputs.values(db), u8::max) +} + +fn max_initial(_db: &dyn Db, _inputs: Inputs) -> Value { + Value::N(0) +} + +/// Query minimum value of inputs, without cycle recovery. 
+#[salsa::tracked] +fn min_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { + fold_values(inputs.values(db), u8::min) +} + +/// Query maximum value of inputs, without cycle recovery. +#[salsa::tracked] +fn max_panic<'db>(db: &'db dyn Db, inputs: Inputs) -> Value { + fold_values(inputs.values(db), u8::max) +} + +fn untracked(num: u8) -> Input { + Input::UntrackedRead(Value::N(num)) +} + +fn value(num: u8) -> Input { + Input::Value(Value::N(num)) +} + +// Diagram nomenclature for nodes: Each node is represented as a:xx(ii), where `a` is a sequential +// identifier from `a`, `b`, `c`..., xx is one of the four query kinds: +// - `Ni` for `min_iterate` +// - `Xi` for `max_iterate` +// - `Np` for `min_panic` +// - `Xp` for `max_panic` +//\ +// and `ii` is the inputs for that query, represented as a comma-separated list, with each +// component representing an input: +// - `a`, `b`, `c`... where the input is another node, +// - `uXX` for `UntrackedRead(XX)` +// - `vXX` for `Value(XX)` +// - `sY` for `Successor(Y)` +// - `zY` for `SuccessorOrZero(Y)` +// +// We always enter from the top left node in the diagram. + +/// a:Np(a) -+ +/// ^ | +/// +--------+ +/// +/// Simple self-cycle, no iteration, should panic. +#[test] +#[should_panic(expected = "dependency graph cycle")] +fn self_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + a_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.eval(&db); +} + +/// a:Np(u10, a) -+ +/// ^ | +/// +-------------+ +/// +/// Simple self-cycle with untracked read, no iteration, should panic. +#[test] +#[should_panic(expected = "dependency graph cycle")] +fn self_untracked_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + a_in.set_inputs(&mut db).to(vec![untracked(10), a.clone()]); + + a.eval(&db); +} + +/// a:Ni(a) -+ +/// ^ | +/// +--------+ +/// +/// Simple self-cycle, iteration converges on initial value. 
+#[test] +fn self_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + a_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 255); +} + +/// a:Ni(b) --> b:Np(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, one with iteration and one without. +/// If we enter from the one with iteration, we converge on its initial value. +#[test] +fn two_mixed_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 255); +} + +/// a:Np(b) --> b:Ni(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, one with iteration and one without. +/// If we enter from the one with no iteration, we panic. +#[test] +#[should_panic(expected = "dependency graph cycle")] +fn two_mixed_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(b_in); + let b = Input::MinIterate(a_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.eval(&db); +} + +/// a:Ni(b) --> b:Xi(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, both with iteration. +/// We converge on the initial value of whichever we first enter from. 
+#[test] +fn two_iterate_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MaxIterate(b_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 255); + b.assert_value(&db, 255); +} + +/// a:Xi(b) --> b:Ni(a) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, both with iteration. +/// We converge on the initial value of whichever we enter from. +/// (Same setup as above test, different query order.) +#[test] +fn two_iterate_converge_initial_value_2() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MinIterate(b_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 0); + b.assert_value(&db, 0); +} + +/// a:Np(b) --> b:Ni(c) --> c:Xp(b) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, enter indirectly at node with iteration, converge on its initial value. +#[test] +fn two_indirect_iterate_converge_initial_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![b]); + + a.assert_value(&db, 255); +} + +/// a:Xp(b) --> b:Np(c) --> c:Xi(b) +/// ^ | +/// +-----------------+ +/// +/// Two-query cycle, enter indirectly at node without iteration, panic. 
+#[test] +#[should_panic(expected = "dependency graph cycle")] +fn two_indirect_panic() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinPanic(b_in); + let c = Input::MaxIterate(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![b]); + + a.eval(&db); +} + +/// a:Np(b) -> b:Ni(v200,c) -> c:Xp(b) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, converges to non-initial value. +#[test] +fn two_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(200), c]); + c_in.set_inputs(&mut db).to(vec![b]); + + a.assert_value(&db, 200); +} + +/// a:Xp(b) -> b:Xi(v20,c) -> c:Xp(sb) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, falls back due to >3 iterations. +#[test] +fn two_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxPanic(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(20), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Successor(Box::new(b))]); + + a.assert_count(&db); +} + +/// a:Xp(b) -> b:Xi(v20,c) -> c:Xp(zb) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, falls back but fallback does not converge. 
+#[test] +#[should_panic(expected = "fallback did not converge")] +fn two_fallback_diverge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxPanic(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(20), c.clone()]); + c_in.set_inputs(&mut db) + .to(vec![Input::SuccessorOrZero(Box::new(b))]); + + a.assert_count(&db); +} + +/// a:Xp(b) -> b:Xi(v244,c) -> c:Xp(sb) +/// ^ | +/// +---------------------+ +/// +/// Two-query cycle, falls back due to value reaching >MAX_VALUE (we start at 244 and each +/// iteration increments until we reach >245). +#[test] +fn two_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxPanic(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(244), c]); + c_in.set_inputs(&mut db) + .to(vec![Input::Successor(Box::new(b))]); + + a.assert_bounds(&db); +} + +/// a:Ni(b) -> b:Np(a, c) -> c:Np(v25, a) +/// ^ | | +/// +----------+------------------------+ +/// +/// Three-query cycle, (b) and (c) both depend on (a). We converge on 25. 
+#[test] +fn three_fork_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db).to(vec![value(25), a.clone()]); + + a.assert_value(&db, 25); +} + +/// a:Ni(b) -> b:Ni(a, c) -> c:Np(v25, b) +/// ^ | ^ | +/// +----------+ +----------+ +/// +/// Layered cycles. We converge on 25. +#[test] +fn layered_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db).to(vec![value(25), b]); + + a.assert_value(&db, 25); +} + +/// a:Xi(b) -> b:Xi(a, c) -> c:Xp(v25, sb) +/// ^ | ^ | +/// +----------+ +----------+ +/// +/// Layered cycles. We hit max iterations and fall back. +#[test] +fn layered_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db) + .to(vec![value(25), Input::Successor(Box::new(b))]); + a.assert_count(&db); +} + +/// a:Xi(b) -> b:Xi(a, c) -> c:Xp(v243, sb) +/// ^ | ^ | +/// +----------+ +----------+ +/// +/// Layered cycles. We hit max value and fall back. 
+#[test] +fn layered_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![a.clone(), c]); + c_in.set_inputs(&mut db) + .to(vec![value(243), Input::Successor(Box::new(b))]); + + a.assert_bounds(&db); +} + +/// a:Ni(b) -> b:Ni(c) -> c:Np(v25, a, b) +/// ^ ^ | +/// +----------+------------------------+ +/// +/// Nested cycles. We converge on 25. +#[test] +fn nested_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![value(25), a.clone(), b]); + + a.assert_value(&db, 25); +} + +/// a:Ni(b) -> b:Ni(c) -> c:Np(v25, b, a) +/// ^ ^ | +/// +----------+------------------------+ +/// +/// Nested cycles, inner first. We converge on 25. +#[test] +fn nested_inner_first_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db).to(vec![value(25), b, a.clone()]); + + a.assert_value(&db, 25); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v25, a, sb) +/// ^ ^ | +/// +----------+-------------------------+ +/// +/// Nested cycles. We hit max iterations and fall back. 
+#[test] +fn nested_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db) + .to(vec![value(25), a.clone(), Input::Successor(Box::new(b))]); + + a.assert_count(&db); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v25, b, sa) +/// ^ ^ | +/// +----------+-------------------------+ +/// +/// Nested cycles, inner first. We hit max iterations and fall back. +#[test] +fn nested_inner_first_fallback_count() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db) + .to(vec![value(25), b, Input::Successor(Box::new(a.clone()))]); + + a.assert_count(&db); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v243, a, sb) +/// ^ ^ | +/// +----------+--------------------------+ +/// +/// Nested cycles. We hit max value and fall back. 
+#[test] +fn nested_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c.clone()]); + c_in.set_inputs(&mut db).to(vec![ + value(243), + a.clone(), + Input::Successor(Box::new(b.clone())), + ]); + a.assert_bounds(&db); + b.assert_bounds(&db); + c.assert_bounds(&db); +} + +/// a:Xi(b) -> b:Xi(c) -> c:Xp(v243, b, sa) +/// ^ ^ | +/// +----------+--------------------------+ +/// +/// Nested cycles, inner first. We hit max value and fall back. +#[test] +fn nested_inner_first_fallback_value() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c]); + c_in.set_inputs(&mut db) + .to(vec![value(243), b, Input::Successor(Box::new(a.clone()))]); + + a.assert_bounds(&db); +} + +/// a:Ni(b) -> b:Ni(c, a) -> c:Np(v25, a, b) +/// ^ ^ | | +/// +----------+--------|------------------+ +/// | | +/// +-------------------+ +/// +/// Nested cycles, double head. We converge on 25. 
+#[test] +fn nested_double_converge() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c, a.clone()]); + c_in.set_inputs(&mut db).to(vec![value(25), a.clone(), b]); + + a.assert_value(&db, 25); +} + +// Multiple-revision cycles + +/// a:Ni(b) --> b:Np(a) +/// ^ | +/// +-----------------+ +/// +/// a:Ni(b) --> b:Np(v30) +/// +/// Cycle becomes not-a-cycle in next revision. +#[test] +fn cycle_becomes_non_cycle() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.clone().assert_value(&db, 255); + + b_in.set_inputs(&mut db).to(vec![value(30)]); + + a.assert_value(&db, 30); +} + +/// a:Ni(b) --> b:Np(v30) +/// +/// a:Ni(b) --> b:Np(a) +/// ^ | +/// +-----------------+ +/// +/// Non-cycle becomes a cycle in next revision. +#[test] +fn non_cycle_becomes_cycle() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinPanic(b_in); + a_in.set_inputs(&mut db).to(vec![b]); + b_in.set_inputs(&mut db).to(vec![value(30)]); + + a.clone().assert_value(&db, 30); + + b_in.set_inputs(&mut db).to(vec![a.clone()]); + + a.assert_value(&db, 255); +} + +/// a:Xi(b) -> b:Xi(c, a) -> c:Xp(v25, a, sb) +/// ^ ^ | | +/// +----------+--------|-------------------+ +/// | | +/// +-------------------+ +/// +/// Nested cycles, double head. We hit max iterations and fall back, then max value on the next +/// revision, then converge on the next. 
+#[test] +fn nested_double_multiple_revisions() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MaxIterate(a_in); + let b = Input::MaxIterate(b_in); + let c = Input::MaxPanic(c_in); + a_in.set_inputs(&mut db).to(vec![b.clone()]); + b_in.set_inputs(&mut db).to(vec![c, a.clone()]); + c_in.set_inputs(&mut db).to(vec![ + value(25), + a.clone(), + Input::Successor(Box::new(b.clone())), + ]); + + a.clone().assert_count(&db); + + // next revision, we hit max value instead + c_in.set_inputs(&mut db).to(vec![ + value(243), + a.clone(), + Input::Successor(Box::new(b.clone())), + ]); + + a.clone().assert_bounds(&db); + + // and next revision, we converge + c_in.set_inputs(&mut db) + .to(vec![value(240), a.clone(), b.clone()]); + + a.clone().assert_value(&db, 240); + + // one more revision, without relevant changes + a_in.set_inputs(&mut db).to(vec![b]); + + a.assert_value(&db, 240); +} + +/// a:Ni(b) -> b:Ni(c) -> c:Ni(a) +/// ^ | +/// +---------------------------+ +/// +/// In a cycle with some LOW durability and some HIGH durability inputs, changing a LOW durability +/// input still re-executes the full cycle in the next revision. 
+#[test] +fn cycle_durability() { + let mut db = DbImpl::new(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinIterate(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinIterate(c_in); + a_in.set_inputs(&mut db) + .with_durability(Durability::LOW) + .to(vec![b.clone()]); + b_in.set_inputs(&mut db) + .with_durability(Durability::HIGH) + .to(vec![c]); + c_in.set_inputs(&mut db) + .with_durability(Durability::HIGH) + .to(vec![a.clone()]); + + a.clone().assert_value(&db, 255); + + // next revision, we converge instead + a_in.set_inputs(&mut db) + .with_durability(Durability::LOW) + .to(vec![value(45), b]); + + a.assert_value(&db, 45); +} + +/// a:Np(v59, b) -> b:Ni(v60, c) -> c:Np(b) +/// ^ | +/// +---------------------+ +/// +/// If nothing in a cycle changed in the new revision, no part of the cycle should re-execute. +#[test] +fn cycle_unchanged() { + let mut db = ExecuteValidateLoggerDatabase::default(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + a_in.set_inputs(&mut db).to(vec![value(59), b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(60), c]); + c_in.set_inputs(&mut db).to(vec![b.clone()]); + + a.clone().assert_value(&db, 59); + b.clone().assert_value(&db, 60); + + db.assert_logs_len(4); + + // next revision, we change only A, which is not part of the cycle and the cycle does not + // depend on. 
+ a_in.set_inputs(&mut db).to(vec![value(45), b.clone()]); + b.assert_value(&db, 60); + + db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + ]"#]]); + + a.assert_value(&db, 45); +} + +/// a:Np(v59, b) -> b:Ni(v60, c) -> c:Np(d) -> d:Ni(v61, b, e) -> e:Np(d) +/// ^ | ^ | +/// +--------------------------+ +--------------+ +/// +/// If nothing in a nested cycle changed in the new revision, no part of the cycle should +/// re-execute. +#[test] +fn cycle_unchanged_nested() { + let mut db = ExecuteValidateLoggerDatabase::default(); + let a_in = Inputs::new(&db, vec![]); + let b_in = Inputs::new(&db, vec![]); + let c_in = Inputs::new(&db, vec![]); + let d_in = Inputs::new(&db, vec![]); + let e_in = Inputs::new(&db, vec![]); + let a = Input::MinPanic(a_in); + let b = Input::MinIterate(b_in); + let c = Input::MinPanic(c_in); + let d = Input::MinIterate(d_in); + let e = Input::MinPanic(e_in); + a_in.set_inputs(&mut db).to(vec![value(59), b.clone()]); + b_in.set_inputs(&mut db).to(vec![value(60), c.clone()]); + c_in.set_inputs(&mut db).to(vec![d.clone()]); + d_in.set_inputs(&mut db) + .to(vec![value(61), b.clone(), e.clone()]); + e_in.set_inputs(&mut db).to(vec![d.clone()]); + + a.clone().assert_value(&db, 59); + b.clone().assert_value(&db, 60); + + db.assert_logs_len(10); + + // next revision, we change only A, which is not part of the cycle and the cycle does not + // depend on. 
+ a_in.set_inputs(&mut db).to(vec![value(45), b.clone()]); + b.assert_value(&db, 60); + + db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(4)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + ]"#]]); + + a.assert_value(&db, 45); +} + +/// +--------------------------------+ +/// | v +/// a:Np(v59, b) -> b:Ni(v60, c) -> c:Np(d, e) -> d:Ni(v61, b, e) -> e:Ni(d) +/// ^ | ^ | +/// +-----------------------------+ +--------------+ +/// +/// If nothing in a nested cycle changed in the new revision, no part of the cycle should +/// re-execute. +#[test_log::test] +fn cycle_unchanged_nested_intertwined() { + // We run this test twice in order to catch some subtly different cases; see below. 
+    for i in 0..2 {
+        let mut db = ExecuteValidateLoggerDatabase::default();
+        let a_in = Inputs::new(&db, vec![]);
+        let b_in = Inputs::new(&db, vec![]);
+        let c_in = Inputs::new(&db, vec![]);
+        let d_in = Inputs::new(&db, vec![]);
+        let e_in = Inputs::new(&db, vec![]);
+        let a = Input::MinPanic(a_in);
+        let b = Input::MinIterate(b_in);
+        let c = Input::MinPanic(c_in);
+        let d = Input::MinIterate(d_in);
+        let e = Input::MinIterate(e_in);
+        a_in.set_inputs(&mut db).to(vec![value(59), b.clone()]);
+        b_in.set_inputs(&mut db).to(vec![value(60), c.clone()]);
+        c_in.set_inputs(&mut db).to(vec![d.clone(), e.clone()]);
+        d_in.set_inputs(&mut db)
+            .to(vec![value(61), b.clone(), e.clone()]);
+        e_in.set_inputs(&mut db).to(vec![d.clone()]);
+
+        a.clone().assert_value(&db, 59);
+        b.clone().assert_value(&db, 60);
+
+        // First time we run this test, don't fetch c/d/e here; this means they won't get marked
+        // `verified_final` in R6 (this revision), which will leave us in the next revision (R7)
+        // with a chain of could-be-provisional memos from the previous revision which should be
+        // final but were never confirmed as such; this triggers the case in `deep_verify_memo`
+        // where we need to double-check `validate_provisional` after traversing dependencies.
+        //
+        // Second time we run this test, fetch everything in R6, to check the behavior of
+        // `maybe_changed_after` with all validated-final memos.
+        if i == 1 {
+            c.clone().assert_value(&db, 60);
+            d.clone().assert_value(&db, 60);
+            e.clone().assert_value(&db, 60);
+        }
+
+        db.assert_logs_len(16 + i);
+
+        // next revision, we change only A, which is not part of the cycle and the cycle does not
+        // depend on.
+ a_in.set_inputs(&mut db).to(vec![value(45), b.clone()]); + b.assert_value(&db, 60); + + db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(4)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(3)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_panic(Id(2)) })", + "salsa_event(DidValidateMemoizedValue { database_key: min_iterate(Id(1)) })", + ]"#]]); + + a.assert_value(&db, 45); + } +} diff --git a/tests/cycle_accumulate.rs b/tests/cycle_accumulate.rs new file mode 100644 index 000000000..484a6a6fc --- /dev/null +++ b/tests/cycle_accumulate.rs @@ -0,0 +1,325 @@ +use std::collections::HashSet; + +mod common; +use common::{LogDatabase, LoggerDatabase}; + +use expect_test::expect; +use salsa::{Accumulator, Setter}; +use test_log::test; + +#[salsa::input] +struct File { + name: String, + dependencies: Vec, + issues: Vec, +} + +#[salsa::accumulator] +struct Diagnostic(#[allow(dead_code)] String); + +#[salsa::tracked(cycle_fn = cycle_fn, cycle_initial = cycle_initial)] +fn check_file(db: &dyn LogDatabase, file: File) -> Vec { + db.push_log(format!( + "check_file(name = {}, issues = {:?})", + file.name(db), + file.issues(db) + )); + + let mut collected_issues = HashSet::::from_iter(file.issues(db).iter().copied()); + + for dep in file.dependencies(db) { + let issues = check_file(db, dep); + collected_issues.extend(issues); + } + + let mut sorted_issues = collected_issues.iter().copied().collect::>(); + sorted_issues.sort(); + + for issue in &sorted_issues { + Diagnostic(format!("file {}: issue {}", file.name(db), issue)).accumulate(db); + } + + sorted_issues +} + +fn cycle_initial(_db: &dyn LogDatabase, _file: File) -> Vec { + vec![] +} + +fn cycle_fn( + _db: &dyn LogDatabase, + _value: &[u32], + _count: u32, + _file: 
File, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +#[test] +fn accumulate_once() { + let db = LoggerDatabase::default(); + + let file = File::new(&db, "fn".to_string(), vec![], vec![1]); + let diagnostics = check_file::accumulated::(&db, file); + db.assert_logs(expect![[r#" + [ + "check_file(name = fn, issues = [1])", + ]"#]]); + + expect![[r#" + [ + Diagnostic( + "file fn: issue 1", + ), + ]"#]] + .assert_eq(&format!("{:#?}", diagnostics)); +} + +#[test] +fn accumulate_with_dep() { + let db = LoggerDatabase::default(); + + let file_a = File::new(&db, "file_a".to_string(), vec![], vec![1]); + let file_b = File::new(&db, "file_b".to_string(), vec![file_a], vec![2]); + + let diagnostics = check_file::accumulated::(&db, file_b); + db.assert_logs(expect![[r#" + [ + "check_file(name = file_b, issues = [2])", + "check_file(name = file_a, issues = [1])", + ]"#]]); + + expect![[r#" + [ + Diagnostic( + "file file_b: issue 1", + ), + Diagnostic( + "file file_b: issue 2", + ), + Diagnostic( + "file file_a: issue 1", + ), + ]"#]] + .assert_eq(&format!("{:#?}", diagnostics)); +} + +#[test] +fn accumulate_with_cycle() { + let mut db = LoggerDatabase::default(); + + let file_a = File::new(&db, "file_a".to_string(), vec![], vec![1]); + let file_b = File::new(&db, "file_b".to_string(), vec![file_a], vec![2]); + file_a.set_dependencies(&mut db).to(vec![file_b]); + + let diagnostics = check_file::accumulated::(&db, file_b); + db.assert_logs(expect![[r#" + [ + "check_file(name = file_b, issues = [2])", + "check_file(name = file_a, issues = [1])", + "check_file(name = file_b, issues = [2])", + "check_file(name = file_a, issues = [1])", + ]"#]]); + + expect![[r#" + [ + Diagnostic( + "file file_b: issue 1", + ), + Diagnostic( + "file file_b: issue 2", + ), + Diagnostic( + "file file_a: issue 1", + ), + Diagnostic( + "file file_a: issue 2", + ), + ]"#]] + .assert_eq(&format!("{:#?}", diagnostics)); +} + +#[test] +fn 
accumulate_with_cycle_second_revision() { + let mut db = LoggerDatabase::default(); + + let file_a = File::new(&db, "file_a".to_string(), vec![], vec![1]); + let file_b = File::new(&db, "file_b".to_string(), vec![file_a], vec![2]); + file_a.set_dependencies(&mut db).to(vec![file_b]); + + let diagnostics = check_file::accumulated::(&db, file_b); + db.assert_logs(expect![[r#" + [ + "check_file(name = file_b, issues = [2])", + "check_file(name = file_a, issues = [1])", + "check_file(name = file_b, issues = [2])", + "check_file(name = file_a, issues = [1])", + ]"#]]); + + expect![[r#" + [ + Diagnostic( + "file file_b: issue 1", + ), + Diagnostic( + "file file_b: issue 2", + ), + Diagnostic( + "file file_a: issue 1", + ), + Diagnostic( + "file file_a: issue 2", + ), + ]"#]] + .assert_eq(&format!("{:#?}", diagnostics)); + + file_b.set_issues(&mut db).to(vec![2, 3]); + + let diagnostics = check_file::accumulated::(&db, file_a); + db.assert_logs(expect![[r#" + [ + "check_file(name = file_b, issues = [2, 3])", + "check_file(name = file_a, issues = [1])", + "check_file(name = file_b, issues = [2, 3])", + "check_file(name = file_a, issues = [1])", + "check_file(name = file_b, issues = [2, 3])", + ]"#]]); + + expect![[r#" + [ + Diagnostic( + "file file_a: issue 1", + ), + Diagnostic( + "file file_a: issue 2", + ), + Diagnostic( + "file file_a: issue 3", + ), + Diagnostic( + "file file_b: issue 1", + ), + Diagnostic( + "file file_b: issue 2", + ), + Diagnostic( + "file file_b: issue 3", + ), + ]"#]] + .assert_eq(&format!("{:#?}", diagnostics)); +} + +#[test] +fn accumulate_add_cycle() { + let mut db = LoggerDatabase::default(); + + let file_a = File::new(&db, "file_a".to_string(), vec![], vec![1]); + let file_b = File::new(&db, "file_b".to_string(), vec![file_a], vec![2]); + + let diagnostics = check_file::accumulated::(&db, file_b); + db.assert_logs(expect![[r#" + [ + "check_file(name = file_b, issues = [2])", + "check_file(name = file_a, issues = [1])", + ]"#]]); + + 
expect![[r#" + [ + Diagnostic( + "file file_b: issue 1", + ), + Diagnostic( + "file file_b: issue 2", + ), + Diagnostic( + "file file_a: issue 1", + ), + ]"#]] + .assert_eq(&format!("{:#?}", diagnostics)); + + file_a.set_dependencies(&mut db).to(vec![file_b]); + + let diagnostics = check_file::accumulated::(&db, file_a); + db.assert_logs(expect![[r#" + [ + "check_file(name = file_a, issues = [1])", + "check_file(name = file_b, issues = [2])", + "check_file(name = file_a, issues = [1])", + "check_file(name = file_b, issues = [2])", + ]"#]]); + + expect![[r#" + [ + Diagnostic( + "file file_a: issue 1", + ), + Diagnostic( + "file file_a: issue 2", + ), + Diagnostic( + "file file_b: issue 1", + ), + Diagnostic( + "file file_b: issue 2", + ), + ]"#]] + .assert_eq(&format!("{:#?}", diagnostics)); +} + +#[test] +fn accumulate_remove_cycle() { + let mut db = LoggerDatabase::default(); + + let file_a = File::new(&db, "file_a".to_string(), vec![], vec![1]); + let file_b = File::new(&db, "file_b".to_string(), vec![file_a], vec![2]); + file_a.set_dependencies(&mut db).to(vec![file_b]); + + let diagnostics = check_file::accumulated::(&db, file_b); + db.assert_logs(expect![[r#" + [ + "check_file(name = file_b, issues = [2])", + "check_file(name = file_a, issues = [1])", + "check_file(name = file_b, issues = [2])", + "check_file(name = file_a, issues = [1])", + ]"#]]); + + expect![[r#" + [ + Diagnostic( + "file file_b: issue 1", + ), + Diagnostic( + "file file_b: issue 2", + ), + Diagnostic( + "file file_a: issue 1", + ), + Diagnostic( + "file file_a: issue 2", + ), + ]"#]] + .assert_eq(&format!("{:#?}", diagnostics)); + + file_a.set_dependencies(&mut db).to(vec![]); + + let diagnostics = check_file::accumulated::(&db, file_b); + db.assert_logs(expect![[r#" + [ + "check_file(name = file_a, issues = [1])", + "check_file(name = file_b, issues = [2])", + ]"#]]); + + expect![[r#" + [ + Diagnostic( + "file file_b: issue 1", + ), + Diagnostic( + "file file_b: issue 2", + ), + 
Diagnostic( + "file file_a: issue 1", + ), + ]"#]] + .assert_eq(&format!("{:#?}", diagnostics)); +} diff --git a/tests/cycle_initial_call_back_into_cycle.rs b/tests/cycle_initial_call_back_into_cycle.rs new file mode 100644 index 000000000..9dfe39a92 --- /dev/null +++ b/tests/cycle_initial_call_back_into_cycle.rs @@ -0,0 +1,36 @@ +//! Calling back into the same cycle from your cycle initial function will trigger another cycle. + +#[salsa::tracked] +fn initial_value(db: &dyn salsa::Database) -> u32 { + query(db) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query(db: &dyn salsa::Database) -> u32 { + let val = query(db); + if val < 5 { + val + 1 + } else { + val + } +} + +fn cycle_initial(db: &dyn salsa::Database) -> u32 { + initial_value(db) +} + +fn cycle_fn( + _db: &dyn salsa::Database, + _value: &u32, + _count: u32, +) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Iterate +} + +#[test_log::test] +#[should_panic(expected = "dependency graph cycle")] +fn the_test() { + let db = salsa::DatabaseImpl::default(); + + query(&db); +} diff --git a/tests/cycle_initial_call_query.rs b/tests/cycle_initial_call_query.rs new file mode 100644 index 000000000..4c52fff27 --- /dev/null +++ b/tests/cycle_initial_call_query.rs @@ -0,0 +1,35 @@ +//! It's possible to call a Salsa query from within a cycle initial fn. 
+ +#[salsa::tracked] +fn initial_value(_db: &dyn salsa::Database) -> u32 { + 0 +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query(db: &dyn salsa::Database) -> u32 { + let val = query(db); + if val < 5 { + val + 1 + } else { + val + } +} + +fn cycle_initial(db: &dyn salsa::Database) -> u32 { + initial_value(db) +} + +fn cycle_fn( + _db: &dyn salsa::Database, + _value: &u32, + _count: u32, +) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Iterate +} + +#[test_log::test] +fn the_test() { + let db = salsa::DatabaseImpl::default(); + + assert_eq!(query(&db), 5); +} diff --git a/tests/cycle_output.rs b/tests/cycle_output.rs new file mode 100644 index 000000000..e98106284 --- /dev/null +++ b/tests/cycle_output.rs @@ -0,0 +1,179 @@ +//! Test tracked struct output from a query in a cycle. +mod common; +use common::{HasLogger, LogDatabase, Logger}; +use expect_test::expect; +use salsa::Setter; + +#[salsa::tracked] +struct Output<'db> { + value: u32, +} + +#[salsa::input] +struct InputValue { + value: u32, +} + +#[salsa::tracked] +fn read_value<'db>(db: &'db dyn Db, output: Output<'db>) -> u32 { + output.value(db) +} + +#[salsa::tracked] +fn query_a(db: &dyn Db, input: InputValue) -> u32 { + let val = query_b(db, input); + let output = Output::new(db, val); + let read = read_value(db, output); + assert_eq!(read, val); + query_d(db); + if val > 2 { + val + } else { + val + input.value(db) + } +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query_b(db: &dyn Db, input: InputValue) -> u32 { + query_a(db, input) +} + +fn cycle_initial(_db: &dyn Db, _input: InputValue) -> u32 { + 0 +} + +fn cycle_fn( + _db: &dyn Db, + _value: &u32, + _count: u32, + _input: InputValue, +) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Iterate +} + +#[salsa::tracked] +fn query_c(db: &dyn Db, input: InputValue) -> u32 { + input.value(db) +} + +#[salsa::tracked] +fn query_d(db: &dyn Db) -> u32 { + 
db.get_input().map(|input| input.value(db)).unwrap_or(0) +} + +trait HasOptionInput { + fn get_input(&self) -> Option; + fn set_input(&mut self, input: InputValue); +} + +#[salsa::db] +trait Db: HasOptionInput + salsa::Database {} + +#[salsa::db] +#[derive(Clone, Default)] +struct Database { + storage: salsa::Storage, + logger: Logger, + input: Option, +} + +impl HasLogger for Database { + fn logger(&self) -> &Logger { + &self.logger + } +} + +impl HasOptionInput for Database { + fn get_input(&self) -> Option { + self.input + } + + fn set_input(&mut self, input: InputValue) { + self.input.replace(input); + } +} + +#[salsa::db] +impl salsa::Database for Database { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { + let event = event(); + match event.kind { + salsa::EventKind::WillExecute { .. } + | salsa::EventKind::DidValidateMemoizedValue { .. } => { + self.push_log(format!("salsa_event({:?})", event.kind)); + } + _ => {} + } + } +} + +#[salsa::db] +impl Db for Database {} + +#[test_log::test] +fn single_revision() { + let db = Database::default(); + let input = InputValue::new(&db, 1); + + assert_eq!(query_b(&db, input), 3); +} + +#[test_log::test] +fn revalidate_no_changes() { + let mut db = Database::default(); + + let ab_input = InputValue::new(&db, 1); + let c_input = InputValue::new(&db, 10); + assert_eq!(query_c(&db, c_input), 10); + assert_eq!(query_b(&db, ab_input), 3); + + db.assert_logs_len(11); + + // trigger a new revision, but one that doesn't touch the query_a/query_b cycle + c_input.set_value(&mut db).to(20); + + assert_eq!(query_b(&db, ab_input), 3); + + db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(401)) })", + "salsa_event(DidValidateMemoizedValue { database_key: query_d(Id(800)) })", + "salsa_event(DidValidateMemoizedValue { database_key: query_b(Id(0)) })", + "salsa_event(DidValidateMemoizedValue { database_key: query_a(Id(0)) })", + "salsa_event(DidValidateMemoizedValue { 
database_key: query_b(Id(0)) })", + ]"#]]); +} + +#[test_log::test] +fn revalidate_with_change_after_output_read() { + let mut db = Database::default(); + + let ab_input = InputValue::new(&db, 1); + let d_input = InputValue::new(&db, 10); + db.set_input(d_input); + + assert_eq!(query_b(&db, ab_input), 3); + + db.assert_logs_len(10); + + // trigger a new revision that changes the output of query_d + d_input.set_value(&mut db).to(20); + + assert_eq!(query_b(&db, ab_input), 3); + + db.assert_logs(expect![[r#" + [ + "salsa_event(DidValidateMemoizedValue { database_key: read_value(Id(401)) })", + "salsa_event(WillExecute { database_key: query_d(Id(800)) })", + "salsa_event(WillExecute { database_key: query_a(Id(0)) })", + "salsa_event(WillExecute { database_key: read_value(Id(400)) })", + "salsa_event(WillExecute { database_key: query_b(Id(0)) })", + "salsa_event(WillExecute { database_key: query_a(Id(0)) })", + "salsa_event(WillExecute { database_key: query_a(Id(0)) })", + "salsa_event(WillExecute { database_key: read_value(Id(401)) })", + "salsa_event(WillExecute { database_key: query_a(Id(0)) })", + "salsa_event(WillExecute { database_key: read_value(Id(400)) })", + "salsa_event(WillExecute { database_key: query_a(Id(0)) })", + "salsa_event(WillExecute { database_key: read_value(Id(401)) })", + ]"#]]); +} diff --git a/tests/cycle_recovery_call_back_into_cycle.rs b/tests/cycle_recovery_call_back_into_cycle.rs new file mode 100644 index 000000000..a4dc5e250 --- /dev/null +++ b/tests/cycle_recovery_call_back_into_cycle.rs @@ -0,0 +1,43 @@ +//! Calling back into the same cycle from your cycle recovery function _can_ work out, as long as +//! the overall cycle still converges. 
+ +mod common; +use common::{DatabaseWithValue, ValueDatabase}; + +#[salsa::tracked] +fn fallback_value(db: &dyn ValueDatabase) -> u32 { + query(db) + db.get_value() +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query(db: &dyn ValueDatabase) -> u32 { + let val = query(db); + if val < 5 { + val + 1 + } else { + val + } +} + +fn cycle_initial(_db: &dyn ValueDatabase) -> u32 { + 0 +} + +fn cycle_fn(db: &dyn ValueDatabase, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Fallback(fallback_value(db)) +} + +#[test] +fn converges() { + let db = DatabaseWithValue::new(10); + + assert_eq!(query(&db), 10); +} + +#[test] +#[should_panic(expected = "fallback did not converge")] +fn diverges() { + let db = DatabaseWithValue::new(3); + + query(&db); +} diff --git a/tests/cycle_recovery_call_query.rs b/tests/cycle_recovery_call_query.rs new file mode 100644 index 000000000..a768017c8 --- /dev/null +++ b/tests/cycle_recovery_call_query.rs @@ -0,0 +1,35 @@ +//! It's possible to call a Salsa query from within a cycle recovery fn. 
+ +#[salsa::tracked] +fn fallback_value(_db: &dyn salsa::Database) -> u32 { + 10 +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query(db: &dyn salsa::Database) -> u32 { + let val = query(db); + if val < 5 { + val + 1 + } else { + val + } +} + +fn cycle_initial(_db: &dyn salsa::Database) -> u32 { + 0 +} + +fn cycle_fn( + db: &dyn salsa::Database, + _value: &u32, + _count: u32, +) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Fallback(fallback_value(db)) +} + +#[test_log::test] +fn the_test() { + let db = salsa::DatabaseImpl::default(); + + assert_eq!(query(&db), 10); +} diff --git a/tests/cycle_regression_455.rs b/tests/cycle_regression_455.rs new file mode 100644 index 000000000..5beff8d3d --- /dev/null +++ b/tests/cycle_regression_455.rs @@ -0,0 +1,55 @@ +use salsa::{Database, Setter}; + +#[salsa::tracked] +fn memoized(db: &dyn Database, input: MyInput) -> u32 { + memoized_a(db, MyTracked::new(db, input.field(db))) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn memoized_a<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> u32 { + MyTracked::new(db, 0); + memoized_b(db, tracked) +} + +fn cycle_fn<'db>( + _db: &'db dyn Database, + _value: &u32, + _count: u32, + _input: MyTracked<'db>, +) -> salsa::CycleRecoveryAction { + salsa::CycleRecoveryAction::Iterate +} + +fn cycle_initial(_db: &dyn Database, _input: MyTracked) -> u32 { + 0 +} + +#[salsa::tracked] +fn memoized_b<'db>(db: &'db dyn Database, tracked: MyTracked<'db>) -> u32 { + let incr = tracked.field(db); + let a = memoized_a(db, tracked); + if a > 8 { + a + } else { + a + incr + } +} + +#[salsa::input] +struct MyInput { + field: u32, +} + +#[salsa::tracked] +struct MyTracked<'db> { + field: u32, +} + +#[test] +fn cycle_memoized() { + let mut db = salsa::DatabaseImpl::new(); + let input = MyInput::new(&db, 2); + assert_eq!(memoized(&db, input), 10); + input.set_field(&mut db).to(3); + assert_eq!(memoized(&db, input), 9); +} diff 
--git a/tests/cycles.rs b/tests/cycles.rs deleted file mode 100644 index be32beb8e..000000000 --- a/tests/cycles.rs +++ /dev/null @@ -1,437 +0,0 @@ -#![allow(warnings)] - -use std::panic::{RefUnwindSafe, UnwindSafe}; - -use expect_test::expect; -use salsa::DatabaseImpl; -use salsa::Durability; - -// Axes: -// -// Threading -// * Intra-thread -// * Cross-thread -- part of cycle is on one thread, part on another -// -// Recovery strategies: -// * Panic -// * Fallback -// * Mixed -- multiple strategies within cycle participants -// -// Across revisions: -// * N/A -- only one revision -// * Present in new revision, not old -// * Present in old revision, not new -// * Present in both revisions -// -// Dependencies -// * Tracked -// * Untracked -- cycle participant(s) contain untracked reads -// -// Layers -// * Direct -- cycle participant is directly invoked from test -// * Indirect -- invoked a query that invokes the cycle -// -// -// | Thread | Recovery | Old, New | Dep style | Layers | Test Name | -// | ------ | -------- | -------- | --------- | ------ | --------- | -// | Intra | Panic | N/A | Tracked | direct | cycle_memoized | -// | Intra | Panic | N/A | Untracked | direct | cycle_volatile | -// | Intra | Fallback | N/A | Tracked | direct | cycle_cycle | -// | Intra | Fallback | N/A | Tracked | indirect | inner_cycle | -// | Intra | Fallback | Both | Tracked | direct | cycle_revalidate | -// | Intra | Fallback | New | Tracked | direct | cycle_appears | -// | Intra | Fallback | Old | Tracked | direct | cycle_disappears | -// | Intra | Fallback | Old | Tracked | direct | cycle_disappears_durability | -// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_1 | -// | Intra | Mixed | N/A | Tracked | direct | cycle_mixed_2 | -// | Cross | Panic | N/A | Tracked | both | parallel/parallel_cycle_none_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_one_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | 
parallel/parallel_cycle_mid_recover.rs | -// | Cross | Fallback | N/A | Tracked | both | parallel/parallel_cycle_all_recover.rs | - -#[derive(PartialEq, Eq, Hash, Clone, Debug, Update)] -struct Error { - cycle: Vec, -} - -use salsa::Database as Db; -use salsa::Setter; -use salsa::Update; - -#[salsa::input] -struct MyInput {} - -#[salsa::tracked] -fn memoized_a(db: &dyn Db, input: MyInput) { - memoized_b(db, input) -} - -#[salsa::tracked] -fn memoized_b(db: &dyn Db, input: MyInput) { - memoized_a(db, input) -} - -#[salsa::tracked] -fn volatile_a(db: &dyn Db, input: MyInput) { - db.report_untracked_read(); - volatile_b(db, input) -} - -#[salsa::tracked] -fn volatile_b(db: &dyn Db, input: MyInput) { - db.report_untracked_read(); - volatile_a(db, input) -} - -/// The queries A, B, and C in `Database` can be configured -/// to invoke one another in arbitrary ways using this -/// enum. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -enum CycleQuery { - None, - A, - B, - C, - AthenC, -} - -#[salsa::input] -struct ABC { - a: CycleQuery, - b: CycleQuery, - c: CycleQuery, -} - -impl CycleQuery { - fn invoke(self, db: &dyn Db, abc: ABC) -> Result<(), Error> { - match self { - CycleQuery::A => cycle_a(db, abc), - CycleQuery::B => cycle_b(db, abc), - CycleQuery::C => cycle_c(db, abc), - CycleQuery::AthenC => { - let _ = cycle_a(db, abc); - cycle_c(db, abc) - } - CycleQuery::None => Ok(()), - } - } -} - -#[salsa::tracked(recovery_fn=recover_a)] -fn cycle_a(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.a(db).invoke(db, abc) -} - -fn recover_a(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { - Err(Error { - cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), - }) -} - -#[salsa::tracked(recovery_fn=recover_b)] -fn cycle_b(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.b(db).invoke(db, abc) -} - -fn recover_b(db: &dyn Db, cycle: &salsa::Cycle, abc: ABC) -> Result<(), Error> { - Err(Error { - cycle: cycle.participant_keys().map(|k| 
format!("{k:?}")).collect(), - }) -} - -#[salsa::tracked] -fn cycle_c(db: &dyn Db, abc: ABC) -> Result<(), Error> { - abc.c(db).invoke(db, abc) -} - -#[track_caller] -fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { - let v = std::panic::catch_unwind(f); - if let Err(d) = &v { - if let Some(cycle) = d.downcast_ref::() { - return cycle.clone(); - } - } - panic!("unexpected value: {:?}", v) -} - -#[test] -fn cycle_memoized() { - salsa::DatabaseImpl::new().attach(|db| { - let input = MyInput::new(db); - let cycle = extract_cycle(|| memoized_a(db, input)); - let expected = expect![[r#" - [ - memoized_a(Id(0)), - memoized_b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&cycle.all_participants(db)); - }) -} - -#[test] -fn cycle_volatile() { - salsa::DatabaseImpl::new().attach(|db| { - let input = MyInput::new(db); - let cycle = extract_cycle(|| volatile_a(db, input)); - let expected = expect![[r#" - [ - volatile_a(Id(0)), - volatile_b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&cycle.all_participants(db)); - }); -} - -#[test] -fn expect_cycle() { - // A --> B - // ^ | - // +-----+ - - salsa::DatabaseImpl::new().attach(|db| { - let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(db, abc).is_err()); - }) -} - -#[test] -fn inner_cycle() { - // A --> B <-- C - // ^ | - // +-----+ - salsa::DatabaseImpl::new().attach(|db| { - let abc = ABC::new(db, CycleQuery::B, CycleQuery::A, CycleQuery::B); - let err = cycle_c(db, abc); - assert!(err.is_err()); - let expected = expect![[r#" - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&err.unwrap_err().cycle); - }) -} - -#[test] -fn cycle_revalidate() { - // A --> B - // ^ | - // +-----+ - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - abc.set_b(&mut db).to(CycleQuery::A); // same value as default - assert!(cycle_a(&db, abc).is_err()); 
-} - -#[test] -fn cycle_recovery_unchanged_twice() { - // A --> B - // ^ | - // +-----+ - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - - abc.set_c(&mut db).to(CycleQuery::A); // force new revision - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_appears() { - let mut db = salsa::DatabaseImpl::new(); - // A --> B - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::None, CycleQuery::None); - assert!(cycle_a(&db, abc).is_ok()); - - // A --> B - // ^ | - // +-----+ - abc.set_b(&mut db).to(CycleQuery::A); - assert!(cycle_a(&db, abc).is_err()); -} - -#[test] -fn cycle_disappears() { - let mut db = salsa::DatabaseImpl::new(); - - // A --> B - // ^ | - // +-----+ - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - assert!(cycle_a(&db, abc).is_err()); - - // A --> B - abc.set_b(&mut db).to(CycleQuery::None); - assert!(cycle_a(&db, abc).is_ok()); -} - -/// A variant on `cycle_disappears` in which the values of -/// `a` and `b` are set with durability values. -/// If we are not careful, this could cause us to overlook -/// the fact that the cycle will no longer occur. -#[test] -fn cycle_disappears_durability() { - let mut db = salsa::DatabaseImpl::new(); - let abc = ABC::new( - &mut db, - CycleQuery::None, - CycleQuery::None, - CycleQuery::None, - ); - abc.set_a(&mut db) - .with_durability(Durability::LOW) - .to(CycleQuery::B); - abc.set_b(&mut db) - .with_durability(Durability::HIGH) - .to(CycleQuery::A); - - assert!(cycle_a(&db, abc).is_err()); - - // At this point, `a` read `LOW` input, and `b` read `HIGH` input. However, - // because `b` participates in the same cycle as `a`, its final durability - // should be `LOW`. - // - // Check that setting a `LOW` input causes us to re-execute `b` query, and - // observe that the cycle goes away. 
- abc.set_a(&mut db) - .with_durability(Durability::LOW) - .to(CycleQuery::None); - - assert!(cycle_b(&mut db, abc).is_ok()); -} - -#[test] -fn cycle_mixed_1() { - salsa::DatabaseImpl::new().attach(|db| { - // A --> B <-- C - // | ^ - // +-----+ - let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::B); - - let expected = expect![[r#" - [ - "cycle_b(Id(0))", - "cycle_c(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&cycle_c(db, abc).unwrap_err().cycle); - }) -} - -#[test] -fn cycle_mixed_2() { - salsa::DatabaseImpl::new().attach(|db| { - // Configuration: - // - // A --> B --> C - // ^ | - // +-----------+ - let abc = ABC::new(db, CycleQuery::B, CycleQuery::C, CycleQuery::A); - let expected = expect![[r#" - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - "cycle_c(Id(0))", - ] - "#]]; - expected.assert_debug_eq(&cycle_a(db, abc).unwrap_err().cycle); - }) -} - -#[test] -fn cycle_deterministic_order() { - // No matter whether we start from A or B, we get the same set of participants: - let f = || { - let mut db = salsa::DatabaseImpl::new(); - - // A --> B - // ^ | - // +-----+ - let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); - (db, abc) - }; - let (db, abc) = f(); - let a = cycle_a(&db, abc); - let (db, abc) = f(); - let b = cycle_b(&db, abc); - let expected = expect![[r#" - ( - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - ) - "#]]; - expected.assert_debug_eq(&(a.unwrap_err().cycle, b.unwrap_err().cycle)); -} - -#[test] -fn cycle_multiple() { - // No matter whether we start from A or B, we get the same set of participants: - let mut db = salsa::DatabaseImpl::new(); - - // Configuration: - // - // A --> B <-- C - // ^ | ^ - // +-----+ | - // | | - // +-----+ - // - // Here, conceptually, B encounters a cycle with A and then - // recovers. 
- let abc = ABC::new(&db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A); - - let c = cycle_c(&db, abc); - let b = cycle_b(&db, abc); - let a = cycle_a(&db, abc); - let expected = expect![[r#" - ( - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - [ - "cycle_a(Id(0))", - "cycle_b(Id(0))", - ], - ) - "#]]; - expected.assert_debug_eq(&( - c.unwrap_err().cycle, - b.unwrap_err().cycle, - a.unwrap_err().cycle, - )); -} - -#[test] -fn cycle_recovery_set_but_not_participating() { - salsa::DatabaseImpl::new().attach(|db| { - // A --> C -+ - // ^ | - // +--+ - let abc = ABC::new(db, CycleQuery::C, CycleQuery::None, CycleQuery::C); - - // Here we expect C to panic and A not to recover: - let r = extract_cycle(|| drop(cycle_a(db, abc))); - let expected = expect![[r#" - [ - cycle_c(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&r.all_participants(db)); - }) -} diff --git a/tests/dataflow.rs b/tests/dataflow.rs new file mode 100644 index 000000000..b5784d8e2 --- /dev/null +++ b/tests/dataflow.rs @@ -0,0 +1,246 @@ +//! Test case for fixpoint iteration cycle resolution. +//! +//! This test case is intended to simulate a (very simplified) version of a real dataflow analysis +//! using fixpoint iteration. +use salsa::{CycleRecoveryAction, Database as Db, Setter}; +use std::collections::BTreeSet; +use std::iter::IntoIterator; + +/// A Use of a symbol. +#[salsa::input] +struct Use { + reaching_definitions: Vec, +} + +/// A Definition of a symbol, either of the form `base + increment` or `0 + increment`. 
+#[salsa::input] +struct Definition { + base: Option, + increment: usize, +} + +#[derive(Eq, PartialEq, Clone, Debug, salsa::Update)] +enum Type { + Bottom, + Values(Box<[usize]>), + Top, +} + +impl Type { + fn join(tys: impl IntoIterator) -> Type { + let mut result = Type::Bottom; + for ty in tys.into_iter() { + result = match (result, ty) { + (result, Type::Bottom) => result, + (_, Type::Top) => Type::Top, + (Type::Top, _) => Type::Top, + (Type::Bottom, ty) => ty, + (Type::Values(a_ints), Type::Values(b_ints)) => { + let mut set = BTreeSet::new(); + set.extend(a_ints); + set.extend(b_ints); + Type::Values(set.into_iter().collect()) + } + } + } + result + } +} + +#[salsa::tracked(cycle_fn=use_cycle_recover, cycle_initial=use_cycle_initial)] +fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { + let defs = u.reaching_definitions(db); + match defs[..] { + [] => Type::Bottom, + [def] => infer_definition(db, def), + _ => Type::join(defs.iter().map(|&def| infer_definition(db, def))), + } +} + +#[salsa::tracked(cycle_fn=def_cycle_recover, cycle_initial=def_cycle_initial)] +fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { + let increment_ty = Type::Values(Box::from([def.increment(db)])); + if let Some(base) = def.base(db) { + let base_ty = infer_use(db, base); + add(&base_ty, &increment_ty) + } else { + increment_ty + } +} + +fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { + Type::Bottom +} + +fn def_cycle_recover( + _db: &dyn Db, + value: &Type, + count: u32, + _def: Definition, +) -> CycleRecoveryAction { + cycle_recover(value, count) +} + +fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { + Type::Bottom +} + +fn use_cycle_recover( + _db: &dyn Db, + value: &Type, + count: u32, + _use: Use, +) -> CycleRecoveryAction { + cycle_recover(value, count) +} + +fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction { + match value { + Type::Bottom => CycleRecoveryAction::Iterate, + Type::Values(_) => { + if count > 4 { + 
CycleRecoveryAction::Fallback(Type::Top) + } else { + CycleRecoveryAction::Iterate + } + } + Type::Top => CycleRecoveryAction::Iterate, + } +} + +fn add(a: &Type, b: &Type) -> Type { + match (a, b) { + (Type::Bottom, _) | (_, Type::Bottom) => Type::Bottom, + (Type::Top, _) | (_, Type::Top) => Type::Top, + (Type::Values(a_ints), Type::Values(b_ints)) => { + let mut set = BTreeSet::new(); + set.extend( + a_ints + .into_iter() + .flat_map(|a| b_ints.into_iter().map(move |b| a + b)), + ); + Type::Values(set.into_iter().collect()) + } + } +} + +/// x = 1 +#[test] +fn simple() { + let db = salsa::DatabaseImpl::new(); + + let def = Definition::new(&db, None, 1); + let u = Use::new(&db, vec![def]); + + let ty = infer_use(&db, u); + + assert_eq!(ty, Type::Values(Box::from([1]))); +} + +/// x = 1 if flag else 2 +#[test] +fn union() { + let db = salsa::DatabaseImpl::new(); + + let def1 = Definition::new(&db, None, 1); + let def2 = Definition::new(&db, None, 2); + let u = Use::new(&db, vec![def1, def2]); + + let ty = infer_use(&db, u); + + assert_eq!(ty, Type::Values(Box::from([1, 2]))); +} + +/// x = 1 if flag else 2; y = x + 1 +#[test] +fn union_add() { + let db = salsa::DatabaseImpl::new(); + + let x1 = Definition::new(&db, None, 1); + let x2 = Definition::new(&db, None, 2); + let x_use = Use::new(&db, vec![x1, x2]); + let y_def = Definition::new(&db, Some(x_use), 1); + let y_use = Use::new(&db, vec![y_def]); + + let ty = infer_use(&db, y_use); + + assert_eq!(ty, Type::Values(Box::from([2, 3]))); +} + +/// x = 1; loop { x = x + 0 } +#[test] +fn cycle_converges_then_diverges() { + let mut db = salsa::DatabaseImpl::new(); + + let def1 = Definition::new(&db, None, 1); + let def2 = Definition::new(&db, None, 0); + let u = Use::new(&db, vec![def1, def2]); + def2.set_base(&mut db).to(Some(u)); + + let ty = infer_use(&db, u); + + // Loop converges on 1 + assert_eq!(ty, Type::Values(Box::from([1]))); + + // Set the increment on x from 0 to 1 + let new_increment = 1; + 
def2.set_increment(&mut db).to(new_increment); + + // Now the loop diverges and we fall back to Top + assert_eq!(infer_use(&db, u), Type::Top); +} + +/// x = 1; loop { x = x + 1 } +#[test] +fn cycle_diverges_then_converges() { + let mut db = salsa::DatabaseImpl::new(); + + let def1 = Definition::new(&db, None, 1); + let def2 = Definition::new(&db, None, 1); + let u = Use::new(&db, vec![def1, def2]); + def2.set_base(&mut db).to(Some(u)); + + let ty = infer_use(&db, u); + + // Loop diverges. Cut it off and fallback to Type::Top + assert_eq!(ty, Type::Top); + + // Set the increment from 1 to 0. + def2.set_increment(&mut db).to(0); + + // Now the loop converges on 1. + assert_eq!(infer_use(&db, u), Type::Values(Box::from([1]))); +} + +/// x = 0; y = 0; loop { x = y + 0; y = x + 0 } +#[test_log::test] +fn multi_symbol_cycle_converges_then_diverges() { + let mut db = salsa::DatabaseImpl::new(); + + let defx0 = Definition::new(&db, None, 0); + let defy0 = Definition::new(&db, None, 0); + let defx1 = Definition::new(&db, None, 0); + let defy1 = Definition::new(&db, None, 0); + let use_x = Use::new(&db, vec![defx0, defx1]); + let use_y = Use::new(&db, vec![defy0, defy1]); + defx1.set_base(&mut db).to(Some(use_y)); + defy1.set_base(&mut db).to(Some(use_x)); + + // Both symbols converge on 0 + assert_eq!(infer_use(&db, use_x), Type::Values(Box::from([0]))); + assert_eq!(infer_use(&db, use_y), Type::Values(Box::from([0]))); + + // Set the increment on x to 0. + defx1.set_increment(&mut db).to(0); + + // Both symbols still converge on 0. + assert_eq!(infer_use(&db, use_x), Type::Values(Box::from([0]))); + assert_eq!(infer_use(&db, use_y), Type::Values(Box::from([0]))); + + // Set the increment on x from 0 to 1. + defx1.set_increment(&mut db).to(1); + + // Now the loop diverges and we fall back to Top. 
+ assert_eq!(infer_use(&db, use_x), Type::Top); + assert_eq!(infer_use(&db, use_y), Type::Top); +} diff --git a/tests/parallel/cycle_a_t1_b_t2.rs b/tests/parallel/cycle_a_t1_b_t2.rs new file mode 100644 index 000000000..aa0b84845 --- /dev/null +++ b/tests/parallel/cycle_a_t1_b_t2.rs @@ -0,0 +1,74 @@ +//! Test a specific cycle scenario: +//! +//! ```text +//! Thread T1 Thread T2 +//! --------- --------- +//! | | +//! v | +//! query_a() | +//! ^ | v +//! | +------------> query_b() +//! | | +//! +--------------------+ +//! ``` + +use salsa::CycleRecoveryAction; + +use crate::setup::{Knobs, KnobsDatabase}; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(3); + +// Signal 1: T1 has entered `query_a` +// Signal 2: T2 has entered `query_b` + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + db.signal(1); + + // Wait for Thread T2 to enter `query_b` before we continue. + db.wait_for(2); + + query_b(db) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + // Wait for Thread T1 to enter `query_a` before we continue. 
+ db.wait_for(1); + + db.signal(2); + + let a_value = query_a(db); + CycleValue(a_value.0 + 1).min(MAX) +} + +fn cycle_fn( + _db: &dyn KnobsDatabase, + _value: &CycleValue, + _count: u32, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn initial(_db: &dyn KnobsDatabase) -> CycleValue { + MIN +} + +#[test_log::test] +fn the_test() { + std::thread::scope(|scope| { + let db_t1 = Knobs::default(); + let db_t2 = db_t1.clone(); + + let t1 = scope.spawn(move || query_a(&db_t1)); + let t2 = scope.spawn(move || query_b(&db_t2)); + + let (r_t1, r_t2) = (t1.join().unwrap(), t2.join().unwrap()); + + assert_eq!((r_t1, r_t2), (MAX, MAX)); + }); +} diff --git a/tests/parallel/cycle_ab_peeping_c.rs b/tests/parallel/cycle_ab_peeping_c.rs new file mode 100644 index 000000000..1c8233fab --- /dev/null +++ b/tests/parallel/cycle_ab_peeping_c.rs @@ -0,0 +1,78 @@ +//! Test a specific cycle scenario: +//! +//! Thread T1 calls A which calls B which calls A. +//! +//! Thread T2 calls C which calls B. +//! +//! The trick is that the call from Thread T2 comes before B has reached a fixed point. +//! We want to be sure that C sees the final value (and blocks until it is complete). + +use salsa::CycleRecoveryAction; + +use crate::setup::{Knobs, KnobsDatabase}; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MID: CycleValue = CycleValue(11); +const MAX: CycleValue = CycleValue(22); + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + let b_value = query_b(db); + + // When we reach the mid point, signal stage 1 (unblocking T2) + // and then wait for T2 to signal stage 2. 
+ if b_value == MID { + db.signal(1); + db.wait_for(2); + } + + b_value +} + +fn cycle_fn( + _db: &dyn KnobsDatabase, + _value: &CycleValue, + _count: u32, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn cycle_initial(_db: &dyn KnobsDatabase) -> CycleValue { + MIN +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + let a_value = query_a(db); + + CycleValue(a_value.0 + 1).min(MAX) +} + +#[salsa::tracked] +fn query_c(db: &dyn KnobsDatabase) -> CycleValue { + // Wait until T1 has reached MID then execute `query_b`. + // This should block and (due to the configuration on our database) signal stage 2. + db.wait_for(1); + + query_b(db) +} + +#[test] +fn the_test() { + std::thread::scope(|scope| { + let db_t1 = Knobs::default(); + + let db_t2 = db_t1.clone(); + db_t2.signal_on_will_block(2); + + let t1 = scope.spawn(move || query_a(&db_t1)); + let t2 = scope.spawn(move || query_c(&db_t2)); + + let (r_t1, r_t2) = (t1.join().unwrap(), t2.join().unwrap()); + + assert_eq!((r_t1, r_t2), (MAX, MAX)); + }); +} diff --git a/tests/parallel/cycle_nested_three_threads.rs b/tests/parallel/cycle_nested_three_threads.rs new file mode 100644 index 000000000..f0ff0e128 --- /dev/null +++ b/tests/parallel/cycle_nested_three_threads.rs @@ -0,0 +1,89 @@ +//! Test a nested-cycle scenario across three threads: +//! +//! ```text +//! Thread T1 Thread T2 Thread T3 +//! --------- --------- --------- +//! | | | +//! v | | +//! query_a() | | +//! ^ | v | +//! | +------------> query_b() | +//! | ^ | v +//! | | +------------> query_c() +//! | | | +//! +------------------+--------------------+ +//! +//! 
``` + +use salsa::CycleRecoveryAction; + +use crate::setup::{Knobs, KnobsDatabase}; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)] +struct CycleValue(u32); + +const MIN: CycleValue = CycleValue(0); +const MAX: CycleValue = CycleValue(3); + +// Signal 1: T1 has entered `query_a` +// Signal 2: T2 has entered `query_b` +// Signal 3: T3 has entered `query_c` + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_a(db: &dyn KnobsDatabase) -> CycleValue { + db.signal(1); + db.wait_for(3); + + query_b(db) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_b(db: &dyn KnobsDatabase) -> CycleValue { + db.wait_for(1); + db.signal(2); + db.wait_for(3); + + let c_value = query_c(db); + CycleValue(c_value.0 + 1).min(MAX) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_c(db: &dyn KnobsDatabase) -> CycleValue { + db.wait_for(2); + db.signal(3); + + let a_value = query_a(db); + let b_value = query_b(db); + CycleValue(a_value.0.max(b_value.0)) +} + +fn cycle_fn( + _db: &dyn KnobsDatabase, + _value: &CycleValue, + _count: u32, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +fn initial(_db: &dyn KnobsDatabase) -> CycleValue { + MIN +} + +#[test_log::test] +fn the_test() { + std::thread::scope(|scope| { + let db_t1 = Knobs::default(); + let db_t2 = db_t1.clone(); + let db_t3 = db_t1.clone(); + + let t1 = scope.spawn(move || query_a(&db_t1)); + let t2 = scope.spawn(move || query_b(&db_t2)); + let t3 = scope.spawn(move || query_c(&db_t3)); + + let r_t1 = t1.join().unwrap(); + let r_t2 = t2.join().unwrap(); + let r_t3 = t3.join().unwrap(); + + assert_eq!((r_t1, r_t2, r_t3), (MAX, MAX, MAX)); + }); +} diff --git a/tests/parallel/cycle_panic.rs b/tests/parallel/cycle_panic.rs new file mode 100644 index 000000000..1c4825549 --- /dev/null +++ b/tests/parallel/cycle_panic.rs @@ -0,0 +1,42 @@ +//! Test for panic in cycle recovery function, in cross-thread cycle. 
+use crate::setup::{Knobs, KnobsDatabase}; + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_a(db: &dyn KnobsDatabase) -> u32 { + db.signal(1); + db.wait_for(2); + query_b(db) +} + +#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)] +fn query_b(db: &dyn KnobsDatabase) -> u32 { + db.wait_for(1); + db.signal(2); + query_a(db) + 1 +} + +fn cycle_fn(_db: &dyn KnobsDatabase, _value: &u32, _count: u32) -> salsa::CycleRecoveryAction { + panic!("cancel!") +} + +fn initial(_db: &dyn KnobsDatabase) -> u32 { + 0 +} + +#[test] +fn execute() { + let db = Knobs::default(); + + std::thread::scope(|scope| { + let db_t1 = db.clone(); + let t1 = scope.spawn(move || query_a(&db_t1)); + + let db_t2 = db.clone(); + let t2 = scope.spawn(move || query_b(&db_t2)); + + // The main thing here is that we don't deadlock. + let (r1, r2) = (t1.join(), t2.join()); + assert!(r1.is_err()); + assert!(r2.is_err()); + }); +} diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs index e01e46546..cf02f64ac 100644 --- a/tests/parallel/main.rs +++ b/tests/parallel/main.rs @@ -1,9 +1,9 @@ mod setup; +mod cycle_a_t1_b_t2; +mod cycle_ab_peeping_c; +mod cycle_nested_three_threads; +mod cycle_panic; mod parallel_cancellation; -mod parallel_cycle_all_recover; -mod parallel_cycle_mid_recover; -mod parallel_cycle_none_recover; -mod parallel_cycle_one_recover; mod parallel_map; mod signal; diff --git a/tests/parallel/parallel_cancellation.rs b/tests/parallel/parallel_cancellation.rs index 6d7993584..2b5a05fbb 100644 --- a/tests/parallel/parallel_cancellation.rs +++ b/tests/parallel/parallel_cancellation.rs @@ -1,6 +1,4 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. +//! Test for thread cancellation. 
use salsa::Cancelled; use salsa::Setter; diff --git a/tests/parallel/parallel_cycle_all_recover.rs b/tests/parallel/parallel_cycle_all_recover.rs deleted file mode 100644 index 08858ef5d..000000000 --- a/tests/parallel/parallel_cycle_all_recover.rs +++ /dev/null @@ -1,104 +0,0 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. - -use crate::setup::Knobs; -use crate::setup::KnobsDatabase; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked(recovery_fn = recover_a1)] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - a2(db, input) -} - -fn recover_a1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_a1"); - key.field(db) * 10 + 1 -} - -#[salsa::tracked(recovery_fn=recover_a2)] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - b1(db, input) -} - -fn recover_a2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_a2"); - key.field(db) * 10 + 2 -} - -#[salsa::tracked(recovery_fn=recover_b1)] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - b2(db, input) -} - -fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b1"); - key.field(db) * 20 + 1 -} - -#[salsa::tracked(recovery_fn=recover_b2)] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - a1(db, input) -} - -fn recover_b2(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b2"); - key.field(db) * 20 + 2 -} - -// Recover cycle test: -// -// The pattern is as follows. 
-// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | signal stage 2 -// (unblocked) wait for stage 3 (blocks) -// a2 | -// b1 (blocks -> stage 3) | -// | (unblocked) -// | b2 -// | a1 (cycle detected, recovers) -// | b2 completes, recovers -// | b1 completes, recovers -// a2 sees cycle, recovers -// a1 completes, recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block(3); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b1(&db, input) - }); - - assert_eq!(thread_a.join().unwrap(), 11); - assert_eq!(thread_b.join().unwrap(), 21); -} diff --git a/tests/parallel/parallel_cycle_mid_recover.rs b/tests/parallel/parallel_cycle_mid_recover.rs deleted file mode 100644 index c41ed32d1..000000000 --- a/tests/parallel/parallel_cycle_mid_recover.rs +++ /dev/null @@ -1,102 +0,0 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. 
- -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // tell thread b we have started - db.signal(1); - - // wait for thread b to block on a1 - db.wait_for(2); - - a2(db, input) -} - -#[salsa::tracked] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // create the cycle - b1(db, input) -} - -#[salsa::tracked(recovery_fn=recover_b1)] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // wait for thread a to have started - db.wait_for(1); - b2(db, input) -} - -fn recover_b1(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b1"); - key.field(db) * 20 + 2 -} - -#[salsa::tracked] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // will encounter a cycle but recover - b3(db, input); - b1(db, input); // hasn't recovered yet - 0 -} - -#[salsa::tracked(recovery_fn=recover_b3)] -pub(crate) fn b3(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // will block on thread a, signaling stage 2 - a1(db, input) -} - -fn recover_b3(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover_b3"); - key.field(db) * 200 + 2 -} - -// Recover cycle test: -// -// The pattern is as follows. 
-// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | | -// | b2 -// | b3 -// | a1 (blocks -> stage 2) -// (unblocked) | -// a2 (cycle detected) | -// b3 recovers -// b2 resumes -// b1 recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block(3); - move || b1(&db, input) - }); - - // We expect that the recovery function yields - // `1 * 20 + 2`, which is returned (and forwarded) - // to b1, and from there to a2 and a1. - assert_eq!(thread_a.join().unwrap(), 22); - assert_eq!(thread_b.join().unwrap(), 22); -} diff --git a/tests/parallel/parallel_cycle_none_recover.rs b/tests/parallel/parallel_cycle_none_recover.rs deleted file mode 100644 index f1f0ee91e..000000000 --- a/tests/parallel/parallel_cycle_none_recover.rs +++ /dev/null @@ -1,78 +0,0 @@ -//! Test a cycle where no queries recover that occurs across threads. -//! See the `../cycles.rs` for a complete listing of cycle tests, -//! both intra and cross thread. 
- -use crate::setup::Knobs; -use crate::setup::KnobsDatabase; -use expect_test::expect; -use salsa::Database; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - b(db, input) -} - -#[salsa::tracked] -pub(crate) fn b(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - - // Now try to execute A - a(db, input) -} - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, -1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block(3); - move || a(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b(&db, input) - }); - - // We expect B to panic because it detects a cycle (it is the one that calls A, ultimately). - // Right now, it panics with a string. - let err_b = thread_b.join().unwrap_err(); - db.attach(|_| { - if let Some(c) = err_b.downcast_ref::() { - let expected = expect![[r#" - [ - a(Id(0)), - b(Id(0)), - ] - "#]]; - expected.assert_debug_eq(&c.all_participants(&db)); - } else { - panic!("b failed in an unexpected way: {:?}", err_b); - } - }); - - // We expect A to propagate a panic, which causes us to use the sentinel - // type `Canceled`. - assert!(thread_a - .join() - .unwrap_err() - .downcast_ref::() - .is_some()); -} diff --git a/tests/parallel/parallel_cycle_one_recover.rs b/tests/parallel/parallel_cycle_one_recover.rs deleted file mode 100644 index 65737797b..000000000 --- a/tests/parallel/parallel_cycle_one_recover.rs +++ /dev/null @@ -1,91 +0,0 @@ -//! Test for cycle recover spread across two threads. -//! See `../cycles.rs` for a complete listing of cycle tests, -//! 
both intra and cross thread. - -use crate::setup::{Knobs, KnobsDatabase}; - -#[salsa::input] -pub(crate) struct MyInput { - field: i32, -} - -#[salsa::tracked] -pub(crate) fn a1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.signal(1); - db.wait_for(2); - - a2(db, input) -} - -#[salsa::tracked(recovery_fn=recover)] -pub(crate) fn a2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - b1(db, input) -} - -fn recover(db: &dyn KnobsDatabase, _cycle: &salsa::Cycle, key: MyInput) -> i32 { - dbg!("recover"); - key.field(db) * 20 + 2 -} - -#[salsa::tracked] -pub(crate) fn b1(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - // Wait to create the cycle until both threads have entered - db.wait_for(1); - db.signal(2); - - // Wait for thread A to block on this thread - db.wait_for(3); - b2(db, input) -} - -#[salsa::tracked] -pub(crate) fn b2(db: &dyn KnobsDatabase, input: MyInput) -> i32 { - a1(db, input) -} - -// Recover cycle test: -// -// The pattern is as follows. -// -// Thread A Thread B -// -------- -------- -// a1 b1 -// | wait for stage 1 (blocks) -// signal stage 1 | -// wait for stage 2 (blocks) (unblocked) -// | signal stage 2 -// (unblocked) wait for stage 3 (blocks) -// a2 | -// b1 (blocks -> stage 3) | -// | (unblocked) -// | b2 -// | a1 (cycle detected) -// a2 recovery fn executes | -// a1 completes normally | -// b2 completes, recovers -// b1 completes, recovers - -#[test] -fn execute() { - let db = Knobs::default(); - - let input = MyInput::new(&db, 1); - - let thread_a = std::thread::spawn({ - let db = db.clone(); - db.knobs().signal_on_will_block(3); - move || a1(&db, input) - }); - - let thread_b = std::thread::spawn({ - let db = db.clone(); - move || b1(&db, input) - }); - - // We expect that the recovery function yields - // `1 * 20 + 2`, which is returned (and forwarded) - // to b1, and from there to a2 and a1. 
- assert_eq!(thread_a.join().unwrap(), 22); - assert_eq!(thread_b.join().unwrap(), 22); -} diff --git a/tests/parallel/setup.rs b/tests/parallel/setup.rs index b29d1b7be..52c0ce227 100644 --- a/tests/parallel/setup.rs +++ b/tests/parallel/setup.rs @@ -11,10 +11,10 @@ use crate::signal::Signal; /// a certain behavior. #[salsa::db] pub(crate) trait KnobsDatabase: Database { - fn knobs(&self) -> &Knobs; - + /// Signal that we are entering stage 1. fn signal(&self, stage: usize); + /// Wait until we reach stage `stage` (no-op if we have already reached that stage). fn wait_for(&self, stage: usize); } @@ -80,10 +80,6 @@ impl salsa::Database for Knobs { #[salsa::db] impl KnobsDatabase for Knobs { - fn knobs(&self) -> &Knobs { - self - } - fn signal(&self, stage: usize) { self.signal.signal(stage); }