Skip to content

Commit 5a435a7

Browse files
committed
bug: Fix missing cycle inputs
1 parent 42f1583 commit 5a435a7

9 files changed

+139
-25
lines changed

src/active_query.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,20 @@ pub(crate) struct ActiveQuery {
6666
}
6767

6868
impl ActiveQuery {
69+
pub(super) fn seed_iteration(
70+
&mut self,
71+
durability: Durability,
72+
changed_at: Revision,
73+
edges: &[QueryEdge],
74+
untracked_read: bool,
75+
) {
76+
assert!(self.input_outputs.is_empty());
77+
self.input_outputs = edges.iter().cloned().collect();
78+
self.durability = self.durability.min(durability);
79+
self.changed_at = self.changed_at.max(changed_at);
80+
self.untracked_read = untracked_read;
81+
}
82+
6983
pub(super) fn add_read(
7084
&mut self,
7185
input: DatabaseKeyIndex,

src/function/backdate.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::function::memo::Memo;
22
use crate::function::{Configuration, IngredientImpl};
33
use crate::zalsa_local::QueryRevisions;
4+
use crate::DatabaseKeyIndex;
45

56
impl<C> IngredientImpl<C>
67
where
@@ -12,6 +13,7 @@ where
1213
pub(super) fn backdate_if_appropriate<'db>(
1314
&self,
1415
old_memo: &Memo<C::Output<'db>>,
16+
index: DatabaseKeyIndex,
1517
revisions: &mut QueryRevisions,
1618
value: &C::Output<'db>,
1719
) {
@@ -24,7 +26,7 @@ where
2426
&& C::values_equal(old_value, value)
2527
{
2628
tracing::debug!(
27-
"value is equal, back-dating to {:?}",
29+
"{index:?} value is equal, back-dating to {:?}",
2830
old_memo.revisions.changed_at,
2931
);
3032

src/function/execute.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ where
101101
// really change, even if some of its inputs have. So we can
102102
// "backdate" its `changed_at` revision to be the same as the
103103
// old value.
104-
self.backdate_if_appropriate(old_memo, &mut revisions, &new_value);
104+
self.backdate_if_appropriate(old_memo, database_key_index, &mut revisions, &new_value);
105105

106106
// Diff the new outputs with the old, to discard any no-longer-emitted
107107
// outputs and update the tracked struct IDs for seeding the next revision.
@@ -115,12 +115,20 @@ where
115115
provisional,
116116
);
117117
}
118-
self.insert_memo(
118+
let memo = self.insert_memo(
119119
zalsa,
120120
id,
121121
Memo::new(Some(new_value), zalsa.current_revision(), revisions),
122122
memo_ingredient_index,
123-
)
123+
);
124+
125+
tracing::info!(
126+
"{:?}: executed query {:?}",
127+
database_key_index,
128+
memo.tracing_debug()
129+
);
130+
131+
memo
124132
}
125133

126134
#[inline]
@@ -255,27 +263,25 @@ where
255263
current_revision: Revision,
256264
id: Id,
257265
) -> (C::Output<'db>, QueryRevisions) {
258-
// If we already executed this query once, then use the tracked-struct ids from the
259-
// previous execution as the starting point for the new one.
260266
if let Some(old_memo) = opt_old_memo {
267+
// If we already executed this query once, then use the tracked-struct ids from the
268+
// previous execution as the starting point for the new one.
261269
active_query.seed_tracked_struct_ids(&old_memo.revisions.tracked_struct_ids);
262-
}
263-
264-
// Query was not previously executed, or value is potentially
265-
// stale, or value is absent. Let's execute!
266-
let new_value = C::execute(db, C::id_to_input(db, id));
267270

268-
if let Some(old_memo) = opt_old_memo {
269271
// Copy over all outputs from a previous iteration.
270272
// This is necessary to ensure that tracked struct created during the previous iteration
271273
// (and are owned by the query) are alive even if the query in this iteration no longer creates them.
272274
// The query not re-creating the tracked struct doesn't guarantee that there
273275
// aren't any other queries depending on it.
274276
if old_memo.may_be_provisional() && old_memo.verified_at.load() == current_revision {
275-
active_query.append_outputs(old_memo.revisions.origin.outputs());
277+
active_query.seed_iteration(&old_memo.revisions);
276278
}
277279
}
278280

281+
// Query was not previously executed, or value is potentially
282+
// stale, or value is absent. Let's execute!
283+
let new_value = C::execute(db, C::id_to_input(db, id));
284+
279285
(new_value, active_query.pop())
280286
}
281287
}

src/function/fetch.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::function::sync::ClaimResult;
44
use crate::function::{Configuration, IngredientImpl, VerifyResult};
55
use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase};
66
use crate::zalsa_local::QueryRevisions;
7-
use crate::Id;
7+
use crate::{Durability, Id};
88

99
impl<C> IngredientImpl<C>
1010
where
@@ -20,6 +20,12 @@ where
2020

2121
self.lru.record_use(id);
2222

23+
tracing::debug!(
24+
"now: {:?}, medium: {:?}",
25+
zalsa.current_revision(),
26+
zalsa.last_changed_revision(Durability::MEDIUM)
27+
);
28+
2329
zalsa_local.report_tracked_read(
2430
self.database_key_index(id),
2531
memo.revisions.durability,

src/function/maybe_changed_after.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ where
188188
memo: &Memo<C::Output<'_>>,
189189
) -> Option<ShallowUpdate> {
190190
tracing::debug!(
191-
"{database_key_index:?}: shallow_verify_memo(memo = {memo:#?})",
191+
"{database_key_index:?}: shallow_verify_memo(memo = {memo:?})",
192192
memo = memo.tracing_debug()
193193
);
194194
let verified_at = memo.verified_at.load();

src/function/memo.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,10 @@ impl<V> Memo<V> {
294294
},
295295
)
296296
.field("verified_at", &self.memo.verified_at)
297-
.field("revisions", &self.memo.revisions)
297+
.field("changed_at", &self.memo.revisions.changed_at)
298+
.field("durability", &self.memo.revisions.durability)
299+
.field("cycle_heads", &self.memo.revisions.cycle_heads)
300+
// .field("revisions", &self.memo.revisions)
298301
.finish()
299302
}
300303
}

src/function/specify.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ where
7676

7777
let memo_ingredient_index = self.memo_ingredient_index(zalsa, key);
7878
if let Some(old_memo) = self.get_memo_from_table_for(zalsa, key, memo_ingredient_index) {
79-
self.backdate_if_appropriate(old_memo, &mut revisions, &value);
79+
self.backdate_if_appropriate(old_memo, database_key_index, &mut revisions, &value);
8080
self.diff_outputs(
8181
zalsa,
8282
db,

src/zalsa_local.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,7 @@ pub(crate) struct QueryRevisions {
372372
impl QueryRevisions {
373373
pub(crate) fn fixpoint_initial(query: DatabaseKeyIndex, revision: Revision) -> Self {
374374
Self {
375+
// TODO: I think it would be fine to use Revision::start here
375376
changed_at: revision,
376377
durability: Durability::MAX,
377378
origin: QueryOrigin::FixpointInitial,
@@ -424,6 +425,16 @@ impl QueryOrigin {
424425
};
425426
opt_edges.into_iter().flat_map(|edges| edges.outputs())
426427
}
428+
429+
pub(crate) fn edges(&self) -> &[QueryEdge] {
430+
let opt_edges = match self {
431+
QueryOrigin::Derived(edges) | QueryOrigin::DerivedUntracked(edges) => Some(edges),
432+
QueryOrigin::Assigned(_) | QueryOrigin::FixpointInitial => None,
433+
};
434+
opt_edges
435+
.map(|edges| &*edges.input_outputs)
436+
.unwrap_or_default()
437+
}
427438
}
428439

429440
/// The edges between a memoized value and other queries in the dependency graph.
@@ -508,18 +519,17 @@ impl ActiveQueryGuard<'_> {
508519
}
509520

510521
/// Append the given `outputs` to the query's output list.
511-
pub(crate) fn append_outputs<I>(&self, outputs: I)
512-
where
513-
I: IntoIterator<Item = DatabaseKeyIndex> + UnwindSafe,
514-
{
522+
pub(crate) fn seed_iteration(&self, previous: &QueryRevisions) {
523+
let durability = previous.durability;
524+
let changed_at = previous.changed_at;
525+
let edges = previous.origin.edges();
526+
let untracked_read = matches!(previous.origin, QueryOrigin::DerivedUntracked(_));
527+
515528
self.local_state.with_query_stack_mut(|stack| {
516529
#[cfg(debug_assertions)]
517530
assert_eq!(stack.len(), self.push_len);
518531
let frame = stack.last_mut().unwrap();
519-
520-
for output in outputs {
521-
frame.add_output(output);
522-
}
532+
frame.seed_iteration(durability, changed_at, edges, untracked_read);
523533
})
524534
}
525535

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//! Test for cycle where only the first iteration of a query depends on the input value.
2+
mod common;
3+
4+
use crate::common::EventLoggerDatabase;
5+
use salsa::{CycleRecoveryAction, Database, Durability, Setter};
6+
7+
#[salsa::input(debug)]
8+
struct Input {
9+
value: u32,
10+
max: u32,
11+
}
12+
13+
#[salsa::interned(debug)]
14+
struct Output<'db> {
15+
#[return_ref]
16+
value: u32,
17+
}
18+
19+
#[salsa::tracked(cycle_fn=query_a_recover, cycle_initial=query_a_initial)]
20+
fn query_a<'db>(db: &'db dyn salsa::Database, input: Input) -> u32 {
21+
query_b(db, input)
22+
}
23+
24+
// Query b also gets low durability because of query_a. How can we avoid that?
25+
// Or is the bug that we loose the durability somehow?
26+
#[salsa::tracked]
27+
fn query_b<'db>(db: &'db dyn salsa::Database, input: Input) -> u32 {
28+
let value = query_a(db, input);
29+
30+
if value < input.max(db) {
31+
// Only the first iteration depends on value but the entire
32+
// cycle must re-run if input changes.
33+
let result = value + input.value(db);
34+
Output::new(db, result);
35+
result
36+
} else {
37+
value
38+
}
39+
}
40+
41+
// Note: Also requires same output or backdating won't happen. but other query output needs to be different at least once to fixpint
42+
fn query_a_initial<'db>(db: &'db dyn Database, input: Input) -> u32 {
43+
0
44+
}
45+
46+
fn query_a_recover<'db>(
47+
_db: &'db dyn Database,
48+
_output: &u32,
49+
_count: u32,
50+
_input: Input,
51+
) -> CycleRecoveryAction<u32> {
52+
CycleRecoveryAction::Iterate
53+
}
54+
55+
#[test_log::test]
56+
fn main() {
57+
let mut db = EventLoggerDatabase::default();
58+
59+
let input = Input::builder(4, 5).durability(Durability::MEDIUM).new(&db);
60+
61+
{
62+
let result = query_a(&db, input);
63+
64+
assert_eq!(result, 8);
65+
}
66+
67+
{
68+
input.set_value(&mut db).to(3);
69+
70+
let result = query_a(&db, input);
71+
assert_eq!(result, 6);
72+
}
73+
}

0 commit comments

Comments
 (0)