Skip to content

Commit f44f2f7

Browse files
committed
pass inputs to cycle recovery functions
1 parent 84f5eab commit f44f2f7

File tree

8 files changed

+55
-20
lines changed

8 files changed

+55
-20
lines changed

components/salsa-macro-rules/src/setup_tracked_fn.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,17 @@ macro_rules! setup_tracked_fn {
177177
$inner($db, $($input_id),*)
178178
}
179179

180-
fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db) -> Self::Output<$db_lt> {
181-
$($cycle_recovery_initial)*(db)
180+
fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db, ($($input_id),*): ($($input_ty),*)) -> Self::Output<$db_lt> {
181+
$($cycle_recovery_initial)*(db, $($input_id),*)
182182
}
183183

184184
fn recover_from_cycle<$db_lt>(
185185
db: &$db_lt dyn $Db,
186186
value: &Self::Output<$db_lt>,
187187
count: u32,
188+
($($input_id),*): ($($input_ty),*)
188189
) -> $zalsa::CycleRecoveryAction<Self::Output<$db_lt>> {
189-
$($cycle_recovery_fn)*(db, value, count)
190+
$($cycle_recovery_fn)*(db, value, count, $($input_id),*)
190191
}
191192

192193
fn id_to_input<$db_lt>(db: &$db_lt Self::DbView, key: salsa::Id) -> Self::Input<$db_lt> {
Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
// Macro that generates the body of the cycle recovery function
2-
// for the case where no cycle recovery is possible. Must be a macro
3-
// because the signature types must match the particular tracked function.
2+
// for the case where no cycle recovery is possible. This has to be
3+
// a macro because it can take a variadic number of arguments.
44
#[macro_export]
55
macro_rules! unexpected_cycle_recovery {
6-
($db:ident, $value:ident, $count:ident) => {{
6+
($db:ident, $value:ident, $count:ident, $($other_inputs:ident),*) => {{
77
std::mem::drop($db);
8+
std::mem::drop(($($other_inputs),*));
89
panic!("cannot recover from cycle")
910
}};
1011
}
1112

1213
#[macro_export]
1314
macro_rules! unexpected_cycle_initial {
14-
($db:ident) => {{
15+
($db:ident, $($other_inputs:ident),*) => {{
1516
std::mem::drop($db);
17+
std::mem::drop(($($other_inputs),*));
1618
panic!("no cycle initial value")
1719
}};
1820
}

src/function.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,14 @@ pub trait Configuration: Any {
6767
fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>;
6868

6969
/// Get the cycle recovery initial value.
70-
fn cycle_initial(db: &Self::DbView) -> Self::Output<'_>;
70+
fn cycle_initial<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>;
7171

7272
/// Decide whether to iterate a cycle again or fallback.
7373
fn recover_from_cycle<'db>(
7474
db: &'db Self::DbView,
7575
value: &Self::Output<'db>,
7676
count: u32,
77+
input: Self::Input<'db>,
7778
) -> CycleRecoveryAction<Self::Output<'db>>;
7879
}
7980

src/function/execute.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ where
9090
if !C::values_equal(&new_value, last_provisional_value) {
9191
// We are in a cycle that hasn't converged; ask the user's
9292
// cycle-recovery function what to do:
93-
match C::recover_from_cycle(db, &new_value, iteration_count) {
93+
match C::recover_from_cycle(
94+
db,
95+
&new_value,
96+
iteration_count,
97+
C::id_to_input(db, id),
98+
) {
9499
crate::CycleRecoveryAction::Iterate => {
95100
tracing::debug!("{database_key_index:?}: execute: iterate again");
96101
iteration_count = iteration_count.checked_add(1).expect(

src/function/fetch.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ where
7878
ClaimResult::Retry => return None,
7979
ClaimResult::Cycle => {
8080
return self
81-
.initial_value(db)
81+
.initial_value(db, database_key_index.key_index)
8282
.map(|initial_value| {
8383
tracing::debug!(
8484
"hit cycle at {database_key_index:#?}, \

src/function/memo.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,13 @@ impl<C: Configuration> IngredientImpl<C> {
8888
}
8989
}
9090

91-
pub(super) fn initial_value<'db>(&'db self, db: &'db C::DbView) -> Option<C::Output<'db>> {
91+
pub(super) fn initial_value<'db>(
92+
&'db self,
93+
db: &'db C::DbView,
94+
key: Id,
95+
) -> Option<C::Output<'db>> {
9296
match C::CYCLE_STRATEGY {
93-
CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db)),
97+
CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db, C::id_to_input(db, key))),
9498
CycleRecoveryStrategy::Panic => None,
9599
}
96100
}

tests/cycle/dataflow.rs

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ impl Type {
4747
}
4848
}
4949

50-
#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=cycle_initial)]
50+
#[salsa::tracked(cycle_fn=use_cycle_recover, cycle_initial=use_cycle_initial)]
5151
fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type {
5252
let defs = u.reaching_definitions(db);
5353
match defs[..] {
@@ -57,7 +57,7 @@ fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type {
5757
}
5858
}
5959

60-
#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=cycle_initial)]
60+
#[salsa::tracked(cycle_fn=def_cycle_recover, cycle_initial=def_cycle_initial)]
6161
fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type {
6262
let increment_ty = Type::Values(Box::from([def.increment(db)]));
6363
if let Some(base) = def.base(db) {
@@ -68,11 +68,33 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type {
6868
}
6969
}
7070

71-
fn cycle_initial(_db: &dyn Db) -> Type {
71+
fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type {
7272
Type::Bottom
7373
}
7474

75-
fn cycle_recover(_db: &dyn Db, value: &Type, count: u32) -> CycleRecoveryAction<Type> {
75+
fn def_cycle_recover(
76+
_db: &dyn Db,
77+
value: &Type,
78+
count: u32,
79+
_def: Definition,
80+
) -> CycleRecoveryAction<Type> {
81+
cycle_recover(value, count)
82+
}
83+
84+
fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type {
85+
Type::Bottom
86+
}
87+
88+
fn use_cycle_recover(
89+
_db: &dyn Db,
90+
value: &Type,
91+
count: u32,
92+
_use: Use,
93+
) -> CycleRecoveryAction<Type> {
94+
cycle_recover(value, count)
95+
}
96+
97+
fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction<Type> {
7698
match value {
7799
Type::Bottom => CycleRecoveryAction::Iterate,
78100
Type::Values(_) => {

tests/cycle/main.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ const MIN_COUNT_FALLBACK: u8 = 100;
7676
const MIN_VALUE_FALLBACK: u8 = 5;
7777
const MIN_VALUE: u8 = 10;
7878

79-
fn min_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction<u8> {
79+
fn min_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction<u8> {
8080
if *value < MIN_VALUE {
8181
CycleRecoveryAction::Fallback(MIN_VALUE_FALLBACK)
8282
} else if count > 10 {
@@ -86,7 +86,7 @@ fn min_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction<u8>
8686
}
8787
}
8888

89-
fn min_initial(_db: &dyn Db) -> u8 {
89+
fn min_initial(_db: &dyn Db, _inputs: Inputs) -> u8 {
9090
255
9191
}
9292

@@ -99,7 +99,7 @@ const MAX_COUNT_FALLBACK: u8 = 200;
9999
const MAX_VALUE_FALLBACK: u8 = 250;
100100
const MAX_VALUE: u8 = 245;
101101

102-
fn max_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction<u8> {
102+
fn max_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction<u8> {
103103
if *value > MAX_VALUE {
104104
CycleRecoveryAction::Fallback(MAX_VALUE_FALLBACK)
105105
} else if count > 10 {
@@ -109,7 +109,7 @@ fn max_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction<u8>
109109
}
110110
}
111111

112-
fn max_initial(_db: &dyn Db) -> u8 {
112+
fn max_initial(_db: &dyn Db, _inputs: Inputs) -> u8 {
113113
0
114114
}
115115

0 commit comments

Comments
 (0)