Skip to content

Commit 77fdeaf

Browse files
authored
[refactor] Simplify fetch_hot (#792)
* [refactor] Simplify `fetch_hot` * Review feedback * Reduce diff
1 parent 736e66e commit 77fdeaf

File tree

2 files changed

+56
-22
lines changed

2 files changed

+56
-22
lines changed

src/function/fetch.rs

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use crate::function::{Configuration, IngredientImpl, VerifyResult};
44
use crate::runtime::StampedValue;
55
use crate::table::sync::ClaimResult;
66
use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase};
7-
use crate::zalsa_local::QueryRevisions;
8-
use crate::{AsDynDatabase as _, Id};
7+
use crate::zalsa_local::{ActiveQueryGuard, QueryRevisions, ZalsaLocal};
8+
use crate::{AsDynDatabase as _, DatabaseKeyIndex, Id};
99

1010
impl<C> IngredientImpl<C>
1111
where
@@ -87,17 +87,15 @@ where
8787

8888
let shallow_update = self.shallow_verify_memo(zalsa, database_key_index, memo)?;
8989

90-
if self.validate_may_be_provisional(db, zalsa, database_key_index, memo)
91-
|| self.validate_same_iteration(db, database_key_index, memo)
92-
{
90+
if self.validate_may_be_provisional(db, zalsa, database_key_index, memo) {
9391
self.update_shallow(db, zalsa, database_key_index, memo, shallow_update);
9492

9593
// SAFETY: memo is present in memo_map and we have verified that it is
9694
// still valid for the current revision.
97-
return unsafe { Some(self.extend_memo_lifetime(memo)) };
95+
unsafe { Some(self.extend_memo_lifetime(memo)) }
96+
} else {
97+
None
9898
}
99-
100-
None
10199
}
102100

103101
fn fetch_cold<'db>(
@@ -173,15 +171,14 @@ where
173171
ClaimResult::Claimed(guard) => guard,
174172
};
175173

176-
// Push the query on the stack.
177-
let active_query = db.zalsa_local().push_query(database_key_index, 0);
174+
let mut active_query = LazyActiveQueryGuard::new(database_key_index);
178175

179176
// Now that we've claimed the item, check again to see if there's a "hot" value.
180177
let opt_old_memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
181178
if let Some(old_memo) = opt_old_memo {
182179
if old_memo.value.is_some() {
183180
if let VerifyResult::Unchanged(_, cycle_heads) =
184-
self.deep_verify_memo(db, zalsa, old_memo, &active_query)
181+
self.deep_verify_memo(db, zalsa, old_memo, &mut active_query)
185182
{
186183
if cycle_heads.is_empty() {
187184
// SAFETY: memo is present in memo_map and we have verified that it is
@@ -192,8 +189,36 @@ where
192189
}
193190
}
194191

195-
let memo = self.execute(db, active_query, opt_old_memo);
192+
let memo = self.execute(db, active_query.into_inner(db.zalsa_local()), opt_old_memo);
196193

197194
Some(memo)
198195
}
199196
}
197+
198+
pub(super) struct LazyActiveQueryGuard<'me> {
199+
guard: Option<ActiveQueryGuard<'me>>,
200+
database_key_index: DatabaseKeyIndex,
201+
}
202+
203+
impl<'me> LazyActiveQueryGuard<'me> {
204+
pub(super) fn new(database_key_index: DatabaseKeyIndex) -> Self {
205+
Self {
206+
guard: None,
207+
database_key_index,
208+
}
209+
}
210+
211+
pub(super) const fn database_key_index(&self) -> DatabaseKeyIndex {
212+
self.database_key_index
213+
}
214+
215+
pub(super) fn guard(&mut self, zalsa_local: &'me ZalsaLocal) -> &ActiveQueryGuard<'me> {
216+
self.guard
217+
.get_or_insert_with(|| zalsa_local.push_query(self.database_key_index, 0))
218+
}
219+
220+
pub(super) fn into_inner(self, zalsa_local: &'me ZalsaLocal) -> ActiveQueryGuard<'me> {
221+
self.guard
222+
.unwrap_or_else(|| zalsa_local.push_query(self.database_key_index, 0))
223+
}
224+
}

src/function/maybe_changed_after.rs

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@ use std::sync::atomic::Ordering;
22

33
use crate::accumulator::accumulated_map::InputAccumulatedValues;
44
use crate::cycle::{CycleHeads, CycleRecoveryStrategy};
5+
use crate::function::fetch::LazyActiveQueryGuard;
56
use crate::function::memo::Memo;
67
use crate::function::{Configuration, IngredientImpl};
78
use crate::key::DatabaseKeyIndex;
89
use crate::table::sync::ClaimResult;
910
use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase};
10-
use crate::zalsa_local::{ActiveQueryGuard, QueryEdge, QueryOrigin};
11+
use crate::zalsa_local::{QueryEdge, QueryOrigin};
1112
use crate::{AsDynDatabase as _, Id, Revision};
1213

1314
/// Result of memo validation.
@@ -135,9 +136,9 @@ where
135136
);
136137

137138
// Check if the inputs are still valid. We can just compare `changed_at`.
138-
let active_query = db.zalsa_local().push_query(database_key_index, 0);
139+
let mut active_query = LazyActiveQueryGuard::new(database_key_index);
139140
if let VerifyResult::Unchanged(_, cycle_heads) =
140-
self.deep_verify_memo(db, zalsa, old_memo, &active_query)
141+
self.deep_verify_memo(db, zalsa, old_memo, &mut active_query)
141142
{
142143
return Some(if old_memo.revisions.changed_at > revision {
143144
VerifyResult::Changed
@@ -151,7 +152,11 @@ where
151152
// backdated. In that case, although we will have computed a new memo,
152153
// the value has not logically changed.
153154
if old_memo.value.is_some() {
154-
let memo = self.execute(db, active_query, Some(old_memo));
155+
let memo = self.execute(
156+
db,
157+
active_query.into_inner(db.zalsa_local()),
158+
Some(old_memo),
159+
);
155160
let changed_at = memo.revisions.changed_at;
156161

157162
return Some(if changed_at > revision {
@@ -317,14 +322,14 @@ where
317322
/// Takes an [`ActiveQueryGuard`] argument because this function recursively
318323
/// walks dependencies of `old_memo` and may even execute them to see if their
319324
/// outputs have changed.
320-
pub(super) fn deep_verify_memo(
325+
pub(super) fn deep_verify_memo<'db>(
321326
&self,
322-
db: &C::DbView,
327+
db: &'db C::DbView,
323328
zalsa: &Zalsa,
324329
old_memo: &Memo<C::Output<'_>>,
325-
active_query: &ActiveQueryGuard<'_>,
330+
active_query: &mut LazyActiveQueryGuard<'db>,
326331
) -> VerifyResult {
327-
let database_key_index = active_query.database_key_index;
332+
let database_key_index = active_query.database_key_index();
328333

329334
tracing::debug!(
330335
"{database_key_index:?}: deep_verify_memo(old_memo = {old_memo:#?})",
@@ -333,9 +338,10 @@ where
333338

334339
let shallow_update = self.shallow_verify_memo(zalsa, database_key_index, old_memo);
335340
let shallow_update_possible = shallow_update.is_some();
336-
337341
if let Some(shallow_update) = shallow_update {
338-
if self.validate_may_be_provisional(db, zalsa, database_key_index, old_memo) {
342+
if self.validate_may_be_provisional(db, zalsa, database_key_index, old_memo)
343+
|| self.validate_same_iteration(db, database_key_index, old_memo)
344+
{
339345
self.update_shallow(db, zalsa, database_key_index, old_memo, shallow_update);
340346

341347
return VerifyResult::unchanged();
@@ -365,11 +371,14 @@ where
365371
}
366372
QueryOrigin::Derived(edges) => {
367373
let is_provisional = old_memo.may_be_provisional();
374+
368375
// If the value is from the same revision but is still provisional, consider it changed
369376
if shallow_update_possible && is_provisional {
370377
return VerifyResult::Changed;
371378
}
372379

380+
let _guard = active_query.guard(db.zalsa_local());
381+
373382
let mut cycle_heads = CycleHeads::default();
374383
'cycle: loop {
375384
// Fully tracked inputs? Iterate over the inputs and check them, one by one.

0 commit comments

Comments
 (0)