diff --git a/crates/anvil/src/eth/backend/mem/mod.rs b/crates/anvil/src/eth/backend/mem/mod.rs index 2f295ed3311b9..7f64e0f900bb1 100644 --- a/crates/anvil/src/eth/backend/mem/mod.rs +++ b/crates/anvil/src/eth/backend/mem/mod.rs @@ -1,7 +1,7 @@ //! In-memory blockchain backend. use self::state::trie_storage; -use super::executor::new_evm_with_inspector_ref; +use super::{db::StateDb, executor::new_evm_with_inspector_ref}; use crate::{ config::PruneStateHistoryConfig, eth::{ @@ -99,9 +99,9 @@ use foundry_evm::{ }; use futures::channel::mpsc::{unbounded, UnboundedSender}; use op_alloy_consensus::{TxDeposit, DEPOSIT_TX_TYPE_ID}; -use parking_lot::{Mutex, RwLock}; +use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard}; use revm::{ - db::WrapDatabaseRef, + db::{DbAccount, WrapDatabaseRef}, interpreter::Host, primitives::{BlobExcessGasAndPrice, HashMap, OptimismFields, ResultAndState}, DatabaseCommit, @@ -2265,18 +2265,21 @@ impl Backend { .await? .map(|block| (block.header.hash, block)) { - if let Some(state) = self.states.write().get(&block_hash) { - let block = BlockEnv { - number: block_number, - coinbase: block.header.beneficiary, - timestamp: U256::from(block.header.timestamp), - difficulty: block.header.difficulty, - prevrandao: block.header.mix_hash, - basefee: U256::from(block.header.base_fee_per_gas.unwrap_or_default()), - gas_limit: U256::from(block.header.gas_limit), - ..Default::default() - }; - return Ok(f(Box::new(state), block)); + let read_guard = self.states.upgradable_read(); + + if read_guard.has_state(&block_hash) { + let state_db = read_guard.get_state(&block_hash); + + if let Some(state) = state_db { + return Ok(get_block_env(state, block_number, block, f)); + } + } else { + let mut write_guard = RwLockUpgradableReadGuard::upgrade(read_guard); + let state_db = write_guard.get_on_disk_state(&block_hash); + + if let Some(state) = state_db { + return Ok(get_block_env(state, block_number, block, f)); + } } } @@ -2909,12 +2912,18 @@ impl Backend { pub async fn rollback(&self, common_block: Block) -> Result<(), BlockchainError> { // Get the database at the common block let common_state = { - let mut state = self.states.write(); - let state_db = state - .get(&common_block.header.hash_slow()) - .ok_or(BlockchainError::DataUnavailable)?; - let db_full = state_db.maybe_as_full_db().ok_or(BlockchainError::DataUnavailable)?; - db_full.clone() + let hash = &common_block.header.hash_slow(); + let read_guard = self.states.upgradable_read(); + if read_guard.has_state(hash) { + let db = read_guard.get_state(hash); + + return_state_or_throw_err(db).unwrap() + } else { + let write_guard = RwLockUpgradableReadGuard::upgrade(read_guard); + let db = write_guard.get_state(hash); + + return_state_or_throw_err(db).unwrap() + } }; { @@ -2949,6 +2958,35 @@ impl Backend { } } +fn get_block_env(state: &StateDb, block_number: U256, block: AnyRpcBlock, f: F) -> T +where + F: FnOnce(Box, BlockEnv) -> T, +{ + let block = BlockEnv { + number: block_number, + coinbase: block.header.beneficiary, + timestamp: U256::from(block.header.timestamp), + difficulty: block.header.difficulty, + prevrandao: block.header.mix_hash, + basefee: U256::from(block.header.base_fee_per_gas.unwrap_or_default()), + gas_limit: U256::from(block.header.gas_limit), + ..Default::default() + }; + + f(Box::new(state), block) +} + +fn return_state_or_throw_err( + db: Option<&StateDb>, +) -> Result< + HashMap, + BlockchainError, +> { + let state_db = db.ok_or(BlockchainError::DataUnavailable)?; + let db_full = state_db.maybe_as_full_db().ok_or(BlockchainError::DataUnavailable)?; + Ok(db_full.clone()) +} + /// Get max nonce from transaction pool by address fn get_pool_transactions_nonce( pool_transactions: &[Arc], diff --git a/crates/anvil/src/eth/backend/mem/storage.rs b/crates/anvil/src/eth/backend/mem/storage.rs index e024d97505467..a8e6117c57be1 100644 --- a/crates/anvil/src/eth/backend/mem/storage.rs +++ b/crates/anvil/src/eth/backend/mem/storage.rs @@ -173,17 +173,26 @@ impl InMemoryBlockStates { } } - /// Returns the state for the given `hash` if present - pub fn get(&mut self, hash: &B256) -> Option<&StateDb> { - self.states.get(hash).or_else(|| { - if let Some(state) = self.on_disk_states.get_mut(hash) { - if let Some(cached) = self.disk_cache.read(*hash) { - state.init_from_state_snapshot(cached); - return Some(state); - } + /// Checks if hash exists in in-memory state + pub fn has_state(&self, hash: &B256) -> bool { + self.states.contains_key(hash) + } + + /// Returns the in-memory state for the given `hash` if present + pub fn get_state(&self, hash: &B256) -> Option<&StateDb> { + self.states.get(hash) + } + + /// Returns on-disk state for the given `hash` if present + pub fn get_on_disk_state(&mut self, hash: &B256) -> Option<&StateDb> { + if let Some(state) = self.on_disk_states.get_mut(hash) { + if let Some(cached) = self.disk_cache.read(*hash) { + state.init_from_state_snapshot(cached); + return Some(state); } - None - }) + } + + None } /// Sets the maximum number of stats we keep in memory @@ -671,7 +680,7 @@ mod tests { assert_eq!(storage.on_disk_states.len(), 1); assert!(storage.on_disk_states.contains_key(&one)); - let loaded = storage.get(&one).unwrap(); + let loaded = storage.get_on_disk_state(&one).unwrap(); let acc = loaded.basic_ref(addr).unwrap().unwrap(); assert_eq!(acc.balance, U256::from(1337u64)); @@ -696,13 +705,20 @@ mod tests { // wait for files to be flushed tokio::time::sleep(std::time::Duration::from_secs(1)).await; - assert_eq!(storage.on_disk_states.len(), num_states - storage.min_in_memory_limit); + let on_disk_states_len = num_states - storage.min_in_memory_limit; + assert_eq!(storage.on_disk_states.len(), on_disk_states_len); assert_eq!(storage.present.len(), storage.min_in_memory_limit); for idx in 0..num_states { let hash = B256::from(U256::from(idx)); let addr = Address::from_word(hash); - let loaded = storage.get(&hash).unwrap(); + + let loaded = if idx < on_disk_states_len { + storage.get_on_disk_state(&hash).unwrap() + } else { + storage.get_state(&hash).unwrap() + }; + let acc = loaded.basic_ref(addr).unwrap().unwrap(); let balance = (idx * 2) as u64; assert_eq!(acc.balance, U256::from(balance));