diff --git a/src/cycle.rs b/src/cycle.rs index 3cac79e3d..e1c63d653 100644 --- a/src/cycle.rs +++ b/src/cycle.rs @@ -132,10 +132,6 @@ impl CycleHeads { true } - pub(crate) fn clear(&mut self) { - self.0.clear(); - } - pub(crate) fn update_iteration_count( &mut self, cycle_head_index: DatabaseKeyIndex, diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 6df5b0372..d315ca07a 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -79,19 +79,24 @@ where id: Id, memo_ingredient_index: MemoIngredientIndex, ) -> Option<&'db Memo>> { - let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); - if let Some(memo) = memo_guard { - let database_key_index = self.database_key_index(id); - if memo.value.is_some() - && (self.validate_may_be_provisional(db, zalsa, database_key_index, memo) - || self.validate_same_iteration(db, database_key_index, memo)) - && self.shallow_verify_memo(db, zalsa, database_key_index, memo) - { - // SAFETY: memo is present in memo_map and we have verified that it is - // still valid for the current revision. - return unsafe { Some(self.extend_memo_lifetime(memo)) }; - } + let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index)?; + + memo.value.as_ref()?; + + let database_key_index = self.database_key_index(id); + + let shallow_update = self.shallow_verify_memo(zalsa, database_key_index, memo)?; + + if self.validate_may_be_provisional(db, zalsa, database_key_index, memo) + || self.validate_same_iteration(db, database_key_index, memo) + { + self.update_shallow(db, zalsa, database_key_index, memo, shallow_update); + + // SAFETY: memo is present in memo_map and we have verified that it is + // still valid for the current revision. + return unsafe { Some(self.extend_memo_lifetime(memo)) }; } + None } @@ -120,10 +125,20 @@ where if let Some(memo) = memo_guard { if memo.value.is_some() && memo.revisions.cycle_heads.contains(&database_key_index) - && self.shallow_verify_memo(db, zalsa, database_key_index, memo) { - // SAFETY: memo is present in memo_map. - return unsafe { Some(self.extend_memo_lifetime(memo)) }; + if let Some(shallow_update) = + self.shallow_verify_memo(zalsa, database_key_index, memo) + { + self.update_shallow( + db, + zalsa, + database_key_index, + memo, + shallow_update, + ); + // SAFETY: memo is present in memo_map. + return unsafe { Some(self.extend_memo_lifetime(memo)) }; + } } } // no provisional value; create/insert/return initial provisional value diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 5ed840463..d9a299aec 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -60,10 +60,16 @@ where // Check if we have a verified version: this is the hot path. let memo_guard = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index); - if let Some(memo) = memo_guard { - if self.validate_may_be_provisional(db, zalsa, database_key_index, memo) - && self.shallow_verify_memo(db, zalsa, database_key_index, memo) - { + let Some(memo) = memo_guard else { + // No memo? Assume has changed. + return VerifyResult::Changed; + }; + + if let Some(shallow_update) = self.shallow_verify_memo(zalsa, database_key_index, memo) + { + if self.validate_provisional(db, zalsa, database_key_index, memo) { + self.update_shallow(db, zalsa, database_key_index, memo, shallow_update); + return if memo.revisions.changed_at > revision { VerifyResult::Changed } else { @@ -73,16 +79,14 @@ where ) }; } - if let Some(mcs) = - self.maybe_changed_after_cold(zalsa, db, id, revision, memo_ingredient_index) - { - return mcs; - } else { - // We failed to claim, have to retry. - } + } + + if let Some(mcs) = + self.maybe_changed_after_cold(zalsa, db, id, revision, memo_ingredient_index) + { + return mcs; } else { - // No memo? Assume has changed. - return VerifyResult::Changed; + // We failed to claim, have to retry. } } } @@ -167,7 +171,7 @@ where Some(VerifyResult::Changed) } - /// True if the memo's value and `changed_at` time is still valid in this revision. + /// `Some` if the memo's value and `changed_at` time is still valid in this revision. /// Does only a shallow O(1) check, doesn't walk the dependencies. /// /// In general, a provisional memo (from cycle iteration) does not verify. Since we don't @@ -177,11 +181,10 @@ where #[inline] pub(super) fn shallow_verify_memo( &self, - db: &C::DbView, zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, memo: &Memo>, - ) -> bool { + ) -> Option { tracing::debug!( "{database_key_index:?}: shallow_verify_memo(memo = {memo:#?})", memo = memo.tracing_debug() @@ -191,7 +194,7 @@ where if verified_at == revision_now { // Already verified. - return true; + return Some(ShallowUpdate::Verified); } let last_changed = zalsa.last_changed_revision(memo.revisions.durability); @@ -204,17 +207,31 @@ where ); if last_changed <= verified_at { // No input of the suitable durability has changed since last verified. + Some(ShallowUpdate::HigherDurability(revision_now)) + } else { + None + } + } + + #[inline] + pub(super) fn update_shallow( + &self, + db: &C::DbView, + zalsa: &Zalsa, + database_key_index: DatabaseKeyIndex, + memo: &Memo>, + update: ShallowUpdate, + ) { + if let ShallowUpdate::HigherDurability(revision_now) = update { memo.mark_as_verified( db, revision_now, database_key_index, memo.revisions.accumulated_inputs.load(), ); + memo.mark_outputs_as_verified(zalsa, db.as_dyn_database(), database_key_index); - return true; } - - false } /// Validates this memo if it is a provisional memo. Returns true for non provisional memos or @@ -311,10 +328,15 @@ where old_memo = old_memo.tracing_debug() ); - if self.validate_may_be_provisional(db, zalsa, database_key_index, old_memo) - && self.shallow_verify_memo(db, zalsa, database_key_index, old_memo) - { - return VerifyResult::unchanged(); + let shallow_update = self.shallow_verify_memo(zalsa, database_key_index, old_memo); + let shallow_update_possible = shallow_update.is_some(); + + if let Some(shallow_update) = shallow_update { + if self.validate_provisional(db, zalsa, database_key_index, old_memo) { + self.update_shallow(db, zalsa, database_key_index, old_memo, shallow_update); + + return VerifyResult::unchanged(); + } } match &old_memo.revisions.origin { @@ -339,7 +361,9 @@ where VerifyResult::Changed } QueryOrigin::Derived(edges) => { - if old_memo.may_be_provisional() { + let is_provisional = old_memo.may_be_provisional(); + // If the value is from the same revision but is still provisional, consider it changed + if shallow_update_possible && is_provisional { return VerifyResult::Changed; } @@ -428,15 +452,18 @@ where inputs, ); + if is_provisional { + old_memo + .revisions + .verified_final + .store(true, Ordering::Relaxed); + } + if in_heads { - // Iterate our dependency graph again, starting from the top. We clear the - // cycle heads here because we are starting a fresh traversal. (It might be - // logically clearer to create a new HashSet each time, but clearing the - // existing one is more efficient.) - cycle_heads.clear(); continue 'cycle; } } + break 'cycle VerifyResult::Unchanged( InputAccumulatedValues::Empty, cycle_heads, @@ -446,3 +473,13 @@ where } } } + +#[derive(Copy, Clone, Eq, PartialEq)] +pub(super) enum ShallowUpdate { + /// The memo is from this revision and has already been verified + Verified, + + /// The revision for the memo's durability hasn't changed. It can be marked as verified + /// in this revision. + HigherDurability(Revision), +} diff --git a/tests/cycle_tracked.rs b/tests/cycle_tracked.rs new file mode 100644 index 000000000..11abd27ec --- /dev/null +++ b/tests/cycle_tracked.rs @@ -0,0 +1,191 @@ +//! Tests for cycles where the cycle head is stored on a tracked struct +//! and that tracked struct is freed in a later revision. + +mod common; + +use crate::common::{EventLoggerDatabase, LogDatabase}; +use expect_test::expect; +use salsa::{CycleRecoveryAction, Database, Setter}; + +#[derive(Clone, Debug, Eq, PartialEq, Hash, salsa::Update)] +struct Graph<'db> { + nodes: Vec>, +} + +impl<'db> Graph<'db> { + fn find_node(&self, db: &dyn salsa::Database, name: &str) -> Option> { + self.nodes + .iter() + .find(|node| node.name(db) == name) + .copied() + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +struct Edge { + // Index into `graph.nodes` + to: usize, + cost: usize, +} + +#[salsa::tracked(debug)] +struct Node<'db> { + #[return_ref] + name: String, + + #[return_ref] + #[tracked] + edges: Vec, + + graph: GraphInput, +} + +#[salsa::input(debug)] +struct GraphInput { + simple: bool, +} + +#[salsa::tracked(return_ref)] +fn create_graph(db: &dyn salsa::Database, input: GraphInput) -> Graph<'_> { + if input.simple(db) { + let a = Node::new(db, "a".to_string(), vec![], input); + let b = Node::new(db, "b".to_string(), vec![Edge { to: 0, cost: 20 }], input); + let c = Node::new(db, "c".to_string(), vec![Edge { to: 1, cost: 2 }], input); + + Graph { + nodes: vec![a, b, c], + } + } else { + // ``` + // flowchart TD + // + // A("a") + // B("b") + // C("c") + // D{"d"} + // + // B -- 20 --> D + // C -- 4 --> D + // D -- 4 --> A + // D -- 4 --> B + // ``` + let a = Node::new(db, "a".to_string(), vec![], input); + let b = Node::new(db, "b".to_string(), vec![Edge { to: 3, cost: 20 }], input); + let c = Node::new(db, "c".to_string(), vec![Edge { to: 3, cost: 4 }], input); + let d = Node::new( + db, + "d".to_string(), + vec![Edge { to: 0, cost: 4 }, Edge { to: 1, cost: 4 }], + input, + ); + + Graph { + nodes: vec![a, b, c, d], + } + } +} + +/// Computes the minimum cost from the node with offset `0` to the given node. +#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=max_initial)] +fn cost_to_start<'db>(db: &'db dyn Database, node: Node<'db>) -> usize { + let mut min_cost = usize::MAX; + let graph = create_graph(db, node.graph(db)); + + for edge in node.edges(db) { + if edge.to == 0 { + min_cost = min_cost.min(edge.cost); + } + + let edge_cost_to_start = cost_to_start(db, graph.nodes[edge.to]); + + // We hit a cycle, never take this edge because it will always be more expensive than + // any other edge + if edge_cost_to_start == usize::MAX { + continue; + } + + min_cost = min_cost.min(edge.cost + edge_cost_to_start); + } + + min_cost +} + +fn max_initial(_db: &dyn Database, _node: Node) -> usize { + usize::MAX +} + +fn cycle_recover( + _db: &dyn Database, + _value: &usize, + _count: u32, + _inputs: Node, +) -> CycleRecoveryAction { + CycleRecoveryAction::Iterate +} + +#[test] +fn main() { + let mut db = EventLoggerDatabase::default(); + + let input = GraphInput::new(&db, false); + let graph = create_graph(&db, input); + let c = graph.find_node(&db, "c").unwrap(); + + // Query the cost from `c` to `a`. + // There's a cycle between `b` and `d`, where `d` becomes the cycle head and `b` is a provisional, non finalized result. + assert_eq!(cost_to_start(&db, c), 8); + + // Change the graph, this will remove `d`, leaving `b` pointing to a cycle head that's now collected. + // Querying the cost from `c` to `a` should try to verify the result of `b` and it is important + // that `b` doesn't try to dereference the cycle head (because its memo is now stored on a tracked + // struct that has been freed). + input.set_simple(&mut db).to(true); + + let graph = create_graph(&db, input); + let c = graph.find_node(&db, "c").unwrap(); + + assert_eq!(cost_to_start(&db, c), 22); + + db.assert_logs(expect![[r#" + [ + "WillCheckCancellation", + "WillExecute { database_key: create_graph(Id(0)) }", + "WillCheckCancellation", + "WillExecute { database_key: cost_to_start(Id(402)) }", + "WillCheckCancellation", + "WillCheckCancellation", + "WillExecute { database_key: cost_to_start(Id(403)) }", + "WillCheckCancellation", + "WillCheckCancellation", + "WillExecute { database_key: cost_to_start(Id(400)) }", + "WillCheckCancellation", + "WillCheckCancellation", + "WillExecute { database_key: cost_to_start(Id(401)) }", + "WillCheckCancellation", + "WillCheckCancellation", + "WillCheckCancellation", + "WillCheckCancellation", + "WillCheckCancellation", + "WillExecute { database_key: cost_to_start(Id(401)) }", + "WillCheckCancellation", + "WillCheckCancellation", + "DidSetCancellationFlag", + "WillCheckCancellation", + "WillExecute { database_key: create_graph(Id(0)) }", + "WillDiscardStaleOutput { execute_key: create_graph(Id(0)), output_key: Node(Id(403)) }", + "DidDiscard { key: Node(Id(403)) }", + "DidDiscard { key: cost_to_start(Id(403)) }", + "WillCheckCancellation", + "WillCheckCancellation", + "WillExecute { database_key: cost_to_start(Id(402)) }", + "WillCheckCancellation", + "WillCheckCancellation", + "WillCheckCancellation", + "WillExecute { database_key: cost_to_start(Id(401)) }", + "WillCheckCancellation", + "WillCheckCancellation", + "WillCheckCancellation", + "WillExecute { database_key: cost_to_start(Id(400)) }", + "WillCheckCancellation", + ]"#]]); +}