diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml index 4187ebfb53e..084d123fa58 100644 --- a/tokio-util/Cargo.toml +++ b/tokio-util/Cargo.toml @@ -45,7 +45,7 @@ slab = { version = "0.4.4", optional = true } # Backs `DelayQueue` tracing = { version = "0.1.29", default-features = false, features = ["std"], optional = true } [target.'cfg(tokio_unstable)'.dependencies] -hashbrown = { version = "0.15.0", default-features = false, features = ["raw-entry"], optional = true } +hashbrown = { version = "0.15.0", default-features = false, optional = true } [dev-dependencies] tokio = { version = "1.0.0", path = "../tokio", features = ["full"] } diff --git a/tokio-util/src/task/join_map.rs b/tokio-util/src/task/join_map.rs index d5c5d2a42c2..02dd46998ac 100644 --- a/tokio-util/src/task/join_map.rs +++ b/tokio-util/src/task/join_map.rs @@ -1,5 +1,5 @@ -use hashbrown::hash_map::RawEntryMut; -use hashbrown::HashMap; +use hashbrown::hash_table::Entry; +use hashbrown::{HashMap, HashTable}; use std::borrow::Borrow; use std::collections::hash_map::RandomState; use std::fmt; @@ -103,13 +103,8 @@ use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet}; #[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] pub struct JoinMap<K, V, S = RandomState> { /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`, - /// indexed by their keys and task IDs. - /// - /// The [`Key`] type contains both the task's `K`-typed key provided when - /// spawning tasks, and the task's IDs. The IDs are stored here to resolve - /// hash collisions when looking up tasks based on their pre-computed hash - /// (as stored in the `hashes_by_task` map). - tasks_by_key: HashMap<Key<K>, AbortHandle, S>, + /// indexed by their keys. + tasks_by_key: HashTable<(K, AbortHandle)>, /// A map from task IDs to the hash of the key associated with that task. /// @@ -125,21 +120,6 @@ pub struct JoinMap<K, V, S = RandomState> { tasks: JoinSet<V>, } -/// A [`JoinMap`] key. -/// -/// This holds both a `K`-typed key (the actual key as seen by the user), _and_ -/// a task ID, so that hash collisions between `K`-typed keys can be resolved -/// using either `K`'s `Eq` impl *or* by checking the task IDs. -/// -/// This allows looking up a task using either an actual key (such as when the -/// user queries the map with a key), *or* using a task ID and a hash (such as -/// when removing completed tasks from the map). -#[derive(Debug)] -struct Key<K> { - key: K, - id: Id, -} - impl<K, V> JoinMap<K, V> { /// Creates a new empty `JoinMap`. /// @@ -176,7 +156,7 @@ impl<K, V> JoinMap<K, V> { } } -impl<K, V, S: Clone> JoinMap<K, V, S> { +impl<K, V, S> JoinMap<K, V, S> { /// Creates an empty `JoinMap` which will use the given hash builder to hash /// keys. /// @@ -226,7 +206,7 @@ impl<K, V, S: Clone> JoinMap<K, V, S> { #[must_use] pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self { Self { - tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()), + tasks_by_key: HashTable::with_capacity(capacity), hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder), tasks: JoinSet::new(), } @@ -416,26 +396,26 @@ where } fn insert(&mut self, key: K, abort: AbortHandle) { - let hash = self.hash(&key); + let hash_builder = self.hashes_by_task.hasher(); + let hash = hash_one(hash_builder, &key); let id = abort.id(); - let map_key = Key { id, key }; // Insert the new key into the map of tasks by keys. - let entry = self - .tasks_by_key - .raw_entry_mut() - .from_hash(hash, |k| k.key == map_key.key); + let entry = + self.tasks_by_key + .entry(hash, |(k, _)| *k == key, |(k, _)| hash_one(hash_builder, k)); match entry { - RawEntryMut::Occupied(mut occ) => { + Entry::Occupied(occ) => { // There was a previous task spawned with the same key! Cancel // that task, and remove its ID from the map of hashes by task IDs. - let Key { id: prev_id, .. } = occ.insert_key(map_key); - occ.insert(abort).abort(); - let _prev_hash = self.hashes_by_task.remove(&prev_id); + let (_, abort) = std::mem::replace(occ.into_mut(), (key, abort)); + abort.abort(); + + let _prev_hash = self.hashes_by_task.remove(&abort.id()); debug_assert_eq!(Some(hash), _prev_hash); } - RawEntryMut::Vacant(vac) => { - vac.insert(map_key, abort); + Entry::Vacant(vac) => { + vac.insert((key, abort)); } }; @@ -623,7 +603,7 @@ where // Note: this method iterates over the tasks and keys *without* removing // any entries, so that the keys from aborted tasks can still be // returned when calling `join_next` in the future. - for (Key { ref key, .. }, task) in &self.tasks_by_key { + for (key, task) in &self.tasks_by_key { if predicate(key) { task.abort(); } @@ -638,7 +618,7 @@ where /// [`join_next`]: fn@Self::join_next pub fn keys(&self) -> JoinMapKeys<'_, K, V> { JoinMapKeys { - iter: self.tasks_by_key.keys(), + iter: self.tasks_by_key.iter(), _value: PhantomData, } } @@ -666,7 +646,7 @@ where /// [`join_next`]: fn@Self::join_next /// [task ID]: tokio::task::Id pub fn contains_task(&self, task: &Id) -> bool { - self.get_by_id(task).is_some() + self.hashes_by_task.contains_key(task) } /// Reserves capacity for at least `additional` more tasks to be spawned @@ -690,7 +670,9 @@ where /// ``` #[inline] pub fn reserve(&mut self, additional: usize) { - self.tasks_by_key.reserve(additional); + let hash_builder = self.hashes_by_task.hasher(); + self.tasks_by_key + .reserve(additional, |(k, _)| hash_one(hash_builder, k)); self.hashes_by_task.reserve(additional); } @@ -716,7 +698,9 @@ where #[inline] pub fn shrink_to_fit(&mut self) { self.hashes_by_task.shrink_to_fit(); - self.tasks_by_key.shrink_to_fit(); + let hash_builder = self.hashes_by_task.hasher(); + self.tasks_by_key + .shrink_to_fit(|(k, _)| hash_one(hash_builder, k)); } /// Shrinks the capacity of the map with a lower limit. It will drop @@ -745,27 +729,20 @@ where #[inline] pub fn shrink_to(&mut self, min_capacity: usize) { self.hashes_by_task.shrink_to(min_capacity); - self.tasks_by_key.shrink_to(min_capacity) + let hash_builder = self.hashes_by_task.hasher(); + self.tasks_by_key + .shrink_to(min_capacity, |(k, _)| hash_one(hash_builder, k)) } /// Look up a task in the map by its key, returning the key and abort handle. - fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)> + fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<&'map (K, AbortHandle)> where Q: Hash + Eq, K: Borrow<Q>, { - let hash = self.hash(key); - self.tasks_by_key - .raw_entry() - .from_hash(hash, |k| k.key.borrow() == key) - } - - /// Look up a task in the map by its task ID, returning the key and abort handle. - fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> { - let hash = self.hashes_by_task.get(id)?; - self.tasks_by_key - .raw_entry() - .from_hash(*hash, |k| &k.id == id) + let hash_builder = self.hashes_by_task.hasher(); + let hash = hash_one(hash_builder, key); + self.tasks_by_key.find(hash, |(k, _)| k.borrow() == key) } /// Remove a task from the map by ID, returning the key for that task. @@ -776,28 +753,25 @@ where // Remove the entry for that hash. let entry = self .tasks_by_key - .raw_entry_mut() - .from_hash(hash, |k| k.id == id); - let (Key { id: _key_id, key }, handle) = match entry { - RawEntryMut::Occupied(entry) => entry.remove_entry(), + .find_entry(hash, |(_, abort)| abort.id() == id); + let (key, _) = match entry { + Ok(entry) => entry.remove().0, _ => return None, }; - debug_assert_eq!(_key_id, id); - debug_assert_eq!(id, handle.id()); self.hashes_by_task.remove(&id); Some(key) } +} - /// Returns the hash for a given key. - #[inline] - fn hash<Q: ?Sized>(&self, key: &Q) -> u64 - where - Q: Hash, - { - let mut hasher = self.tasks_by_key.hasher().build_hasher(); - key.hash(&mut hasher); - hasher.finish() - } +/// Returns the hash for a given key. +#[inline] +fn hash_one<S: BuildHasher, Q: ?Sized>(hash_builder: &S, key: &Q) -> u64 +where + Q: Hash, +{ + let mut hasher = hash_builder.build_hasher(); + key.hash(&mut hasher); + hasher.finish() } impl<K, V, S> JoinMap<K, V, S> @@ -831,11 +805,11 @@ impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> { // printing the key and task ID pairs, without format the `Key` struct // itself or the `AbortHandle`, which would just format the task's ID // again. - struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>); - impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> { + struct KeySet<'a, K: fmt::Debug>(&'a HashTable<(K, AbortHandle)>); + impl<K: fmt::Debug> fmt::Debug for KeySet<'_, K> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_map() - .entries(self.0.keys().map(|Key { key, id }| (key, id))) + .entries(self.0.iter().map(|(key, abort)| (key, abort.id()))) .finish() } } @@ -856,31 +830,10 @@ impl<K, V> Default for JoinMap<K, V> { } } -// === impl Key === - -impl<K: Hash> Hash for Key<K> { - // Don't include the task ID in the hash. - #[inline] - fn hash<H: Hasher>(&self, hasher: &mut H) { - self.key.hash(hasher); - } -} - -// Because we override `Hash` for this type, we must also override the -// `PartialEq` impl, so that all instances with the same hash are equal. -impl<K: PartialEq> PartialEq for Key<K> { - #[inline] - fn eq(&self, other: &Self) -> bool { - self.key == other.key - } -} - -impl<K: Eq> Eq for Key<K> {} - /// An iterator over the keys of a [`JoinMap`]. #[derive(Debug, Clone)] pub struct JoinMapKeys<'a, K, V> { - iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>, + iter: hashbrown::hash_table::Iter<'a, (K, AbortHandle)>, /// To make it easier to change `JoinMap` in the future, keep V as a generic /// parameter. _value: PhantomData<&'a V>, @@ -890,7 +843,7 @@ impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> { type Item = &'a K; fn next(&mut self) -> Option<&'a K> { - self.iter.next().map(|key| &key.key) + self.iter.next().map(|(key, _)| key) } fn size_hint(&self) -> (usize, Option<usize>) {