(tokio-util JoinMap) Remove raw-entry feature in favour of HashTable API. #7252

Open · wants to merge 4 commits into master
2 changes: 1 addition & 1 deletion tokio-util/Cargo.toml
@@ -45,7 +45,7 @@ slab = { version = "0.4.4", optional = true } # Backs `DelayQueue`
tracing = { version = "0.1.29", default-features = false, features = ["std"], optional = true }

[target.'cfg(tokio_unstable)'.dependencies]
hashbrown = { version = "0.15.0", default-features = false, features = ["raw-entry"], optional = true }
hashbrown = { version = "0.15.0", default-features = false, optional = true }

[dev-dependencies]
tokio = { version = "1.0.0", path = "../tokio", features = ["full"] }
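
The `raw-entry` feature is no longer needed because `hashbrown::HashTable` is part of the crate's default API surface and requires no feature flag. A minimal sketch of the hash-explicit API that replaces the raw-entry machinery (illustrative only, not code from this PR; the key/value types and `RandomState` are chosen for the example):

use hashbrown::HashTable;
use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;

fn main() {
    // The table stores no hasher of its own; the caller hashes keys up front.
    let hasher = RandomState::new();
    let mut table: HashTable<(&str, u32)> = HashTable::new();

    let hash = hasher.hash_one("answer");
    table.insert_unique(hash, ("answer", 42), |(k, _)| hasher.hash_one(k));

    let hit = table.find(hasher.hash_one("answer"), |(k, _)| *k == "answer");
    assert_eq!(hit, Some(&("answer", 42)));
}
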
151 changes: 52 additions & 99 deletions tokio-util/src/task/join_map.rs
@@ -1,5 +1,5 @@
use hashbrown::hash_map::RawEntryMut;
use hashbrown::HashMap;
use hashbrown::hash_table::Entry;
use hashbrown::{HashMap, HashTable};
use std::borrow::Borrow;
use std::collections::hash_map::RandomState;
use std::fmt;
@@ -103,13 +103,8 @@ use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))]
pub struct JoinMap<K, V, S = RandomState> {
/// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
/// indexed by their keys and task IDs.
///
/// The [`Key`] type contains both the task's `K`-typed key provided when
/// spawning tasks, and the task's IDs. The IDs are stored here to resolve
/// hash collisions when looking up tasks based on their pre-computed hash
/// (as stored in the `hashes_by_task` map).
tasks_by_key: HashMap<Key<K>, AbortHandle, S>,
/// indexed by their keys.
tasks_by_key: HashTable<(K, AbortHandle)>,

/// A map from task IDs to the hash of the key associated with that task.
///
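
With this change, `tasks_by_key` stores plain `(K, AbortHandle)` pairs and no hasher of its own, so the hash builder `S` now lives only on `hashes_by_task`, whose `Id -> u64` entries let a finished task be found again without re-hashing its key. A sketch of the resulting layout, reconstructed from the hunks in this diff (field names and the `u64` hash type are taken from the surrounding code; this is not a verbatim copy of the file):

use hashbrown::{HashMap, HashTable};
use std::collections::hash_map::RandomState;
use tokio::task::{AbortHandle, Id, JoinSet};

#[allow(dead_code)]
struct JoinMapSketch<K, V, S = RandomState> {
    // (key, abort handle) pairs; a hash must be supplied on every access.
    tasks_by_key: HashTable<(K, AbortHandle)>,
    // Task ID -> hash of that task's key; also the owner of the hash builder.
    hashes_by_task: HashMap<Id, u64, S>,
    // The spawned tasks themselves.
    tasks: JoinSet<V>,
}
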
@@ -125,21 +120,6 @@ pub struct JoinMap<K, V, S = RandomState> {
tasks: JoinSet<V>,
}

/// A [`JoinMap`] key.
///
/// This holds both a `K`-typed key (the actual key as seen by the user), _and_
/// a task ID, so that hash collisions between `K`-typed keys can be resolved
/// using either `K`'s `Eq` impl *or* by checking the task IDs.
///
/// This allows looking up a task using either an actual key (such as when the
/// user queries the map with a key), *or* using a task ID and a hash (such as
/// when removing completed tasks from the map).
#[derive(Debug)]
struct Key<K> {
key: K,
id: Id,
}

impl<K, V> JoinMap<K, V> {
/// Creates a new empty `JoinMap`.
///
@@ -176,7 +156,7 @@ impl<K, V> JoinMap<K, V> {
}
}

impl<K, V, S: Clone> JoinMap<K, V, S> {
impl<K, V, S> JoinMap<K, V, S> {
/// Creates an empty `JoinMap` which will use the given hash builder to hash
/// keys.
///
@@ -226,7 +206,7 @@ impl<K, V, S: Clone> JoinMap<K, V, S> {
#[must_use]
pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
Self {
tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()),
tasks_by_key: HashTable::with_capacity(capacity),
hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
tasks: JoinSet::new(),
}
@@ -416,26 +396,26 @@ where
}

fn insert(&mut self, key: K, abort: AbortHandle) {
let hash = self.hash(&key);
let hash_builder = self.hashes_by_task.hasher();
let hash = hash_one(hash_builder, &key);
let id = abort.id();
let map_key = Key { id, key };

// Insert the new key into the map of tasks by keys.
let entry = self
.tasks_by_key
.raw_entry_mut()
.from_hash(hash, |k| k.key == map_key.key);
let entry =
self.tasks_by_key
.entry(hash, |(k, _)| *k == key, |(k, _)| hash_one(hash_builder, k));
match entry {
RawEntryMut::Occupied(mut occ) => {
Entry::Occupied(occ) => {
// There was a previous task spawned with the same key! Cancel
// that task, and remove its ID from the map of hashes by task IDs.
let Key { id: prev_id, .. } = occ.insert_key(map_key);
occ.insert(abort).abort();
let _prev_hash = self.hashes_by_task.remove(&prev_id);
let (_, abort) = std::mem::replace(occ.into_mut(), (key, abort));
abort.abort();

let _prev_hash = self.hashes_by_task.remove(&abort.id());
debug_assert_eq!(Some(hash), _prev_hash);
}
RawEntryMut::Vacant(vac) => {
vac.insert(map_key, abort);
Entry::Vacant(vac) => {
vac.insert((key, abort));
}
};
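
`HashTable::entry` takes three things: the pre-computed hash of the key being inserted, an equality closure that identifies a matching stored entry, and a hashing closure used to re-hash existing entries if the table has to grow. A standalone sketch of the same upsert pattern (a hypothetical `upsert` helper over `(String, u32)` pairs, not the JoinMap code):

use hashbrown::hash_table::Entry;
use hashbrown::HashTable;
use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;

fn upsert(table: &mut HashTable<(String, u32)>, hasher: &RandomState, key: String, value: u32) {
    let hash = hasher.hash_one(&key);
    let entry = table.entry(hash, |(k, _)| *k == key, |(k, _)| hasher.hash_one(k));
    match entry {
        Entry::Occupied(occ) => {
            // Same key already present: swap in the new pair, keeping the old one.
            let _old = std::mem::replace(occ.into_mut(), (key, value));
        }
        Entry::Vacant(vac) => {
            vac.insert((key, value));
        }
    }
}

The diff's `insert` follows this shape, except that the displaced value is the previous task's `AbortHandle`, which is aborted, and that task's ID is removed from `hashes_by_task`.
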

@@ -620,7 +600,7 @@ where
// Note: this method iterates over the tasks and keys *without* removing
// any entries, so that the keys from aborted tasks can still be
// returned when calling `join_next` in the future.
for (Key { ref key, .. }, task) in &self.tasks_by_key {
for (key, task) in &self.tasks_by_key {
if predicate(key) {
task.abort();
}
@@ -635,7 +615,7 @@ where
/// [`join_next`]: fn@Self::join_next
pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
JoinMapKeys {
iter: self.tasks_by_key.keys(),
iter: self.tasks_by_key.iter(),
_value: PhantomData,
}
}
@@ -663,7 +643,7 @@ where
/// [`join_next`]: fn@Self::join_next
/// [task ID]: tokio::task::Id
pub fn contains_task(&self, task: &Id) -> bool {
self.get_by_id(task).is_some()
self.hashes_by_task.contains_key(task)
}

/// Reserves capacity for at least `additional` more tasks to be spawned
@@ -687,7 +667,9 @@ where
/// ```
#[inline]
pub fn reserve(&mut self, additional: usize) {
self.tasks_by_key.reserve(additional);
let hash_builder = self.hashes_by_task.hasher();
self.tasks_by_key
.reserve(additional, |(k, _)| hash_one(hash_builder, k));
self.hashes_by_task.reserve(additional);
}
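
Because the table cannot re-hash entries on its own, every capacity-changing call (`reserve`, and below, `shrink_to_fit` and `shrink_to`) is handed a closure that recomputes an entry's hash, built from `hashes_by_task`'s hasher. The pattern in isolation (illustrative `(String, u32)` table, not the JoinMap code):

use hashbrown::HashTable;
use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;

fn resize(table: &mut HashTable<(String, u32)>, hasher: &RandomState) {
    // Any operation that may move entries needs to know how to re-hash them.
    table.reserve(128, |(k, _)| hasher.hash_one(k));
    table.shrink_to_fit(|(k, _)| hasher.hash_one(k));
}
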

@@ -713,7 +695,9 @@ where
#[inline]
pub fn shrink_to_fit(&mut self) {
self.hashes_by_task.shrink_to_fit();
self.tasks_by_key.shrink_to_fit();
let hash_builder = self.hashes_by_task.hasher();
self.tasks_by_key
.shrink_to_fit(|(k, _)| hash_one(hash_builder, k));
}

/// Shrinks the capacity of the map with a lower limit. It will drop
@@ -742,27 +726,20 @@ where
#[inline]
pub fn shrink_to(&mut self, min_capacity: usize) {
self.hashes_by_task.shrink_to(min_capacity);
self.tasks_by_key.shrink_to(min_capacity)
let hash_builder = self.hashes_by_task.hasher();
self.tasks_by_key
.shrink_to(min_capacity, |(k, _)| hash_one(hash_builder, k))
}

/// Look up a task in the map by its key, returning the key and abort handle.
fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key<K>, &'map AbortHandle)>
fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<&'map (K, AbortHandle)>
where
Q: Hash + Eq,
K: Borrow<Q>,
{
let hash = self.hash(key);
self.tasks_by_key
.raw_entry()
.from_hash(hash, |k| k.key.borrow() == key)
}

/// Look up a task in the map by its task ID, returning the key and abort handle.
fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key<K>, &'map AbortHandle)> {
let hash = self.hashes_by_task.get(id)?;
self.tasks_by_key
.raw_entry()
.from_hash(*hash, |k| &k.id == id)
let hash_builder = self.hashes_by_task.hasher();
let hash = hash_one(hash_builder, key);
self.tasks_by_key.find(hash, |(k, _)| k.borrow() == key)
}

/// Remove a task from the map by ID, returning the key for that task.
@@ -773,28 +750,25 @@ where
// Remove the entry for that hash.
let entry = self
.tasks_by_key
.raw_entry_mut()
.from_hash(hash, |k| k.id == id);
let (Key { id: _key_id, key }, handle) = match entry {
RawEntryMut::Occupied(entry) => entry.remove_entry(),
.find_entry(hash, |(_, abort)| abort.id() == id);
let (key, _) = match entry {
Ok(entry) => entry.remove().0,
_ => return None,
};
debug_assert_eq!(_key_id, id);
debug_assert_eq!(id, handle.id());
self.hashes_by_task.remove(&id);
Some(key)
}
}

/// Returns the hash for a given key.
#[inline]
fn hash<Q: ?Sized>(&self, key: &Q) -> u64
where
Q: Hash,
{
let mut hasher = self.tasks_by_key.hasher().build_hasher();
key.hash(&mut hasher);
hasher.finish()
}
/// Returns the hash for a given key.
#[inline]
fn hash_one<S: BuildHasher, Q: ?Sized>(hash_builder: &S, key: &Q) -> u64
where
Q: Hash,
{
let mut hasher = hash_builder.build_hasher();
key.hash(&mut hasher);
hasher.finish()
}
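
`get_by_key` re-hashes the caller's key and searches with `HashTable::find`, while `remove_by_id` reuses the hash stored in `hashes_by_task` and matches on the task ID via `find_entry`; the free `hash_one` helper above does the same job as `std::hash::BuildHasher::hash_one`. A standalone sketch of the two lookup styles (illustrative `(String, u32)` table and values, not the JoinMap code):

use hashbrown::HashTable;
use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;

fn find_and_remove(table: &mut HashTable<(String, u32)>, hasher: &RandomState) {
    // Lookup by key: compute the hash, then compare stored keys.
    let hash = hasher.hash_one("a");
    if let Some((_, v)) = table.find(hash, |(k, _)| k.as_str() == "a") {
        println!("key \"a\" maps to {v}");
    }

    // Removal by a non-key property under a known hash, mirroring how
    // `remove_by_id` matches on the task ID rather than the key itself.
    if let Ok(entry) = table.find_entry(hash, |(_, v)| *v == 1) {
        let (_removed_key, _removed_value) = entry.remove().0;
    }
}
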

impl<K, V, S> JoinMap<K, V, S>
@@ -828,11 +802,11 @@ impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
// printing the key and task ID pairs, without format the `Key` struct
// itself or the `AbortHandle`, which would just format the task's ID
// again.
struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap<Key<K>, AbortHandle, S>);
impl<K: fmt::Debug, S> fmt::Debug for KeySet<'_, K, S> {
struct KeySet<'a, K: fmt::Debug>(&'a HashTable<(K, AbortHandle)>);
impl<K: fmt::Debug> fmt::Debug for KeySet<'_, K> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_map()
.entries(self.0.keys().map(|Key { key, id }| (key, id)))
.entries(self.0.iter().map(|(key, abort)| (key, abort.id())))
.finish()
}
}
@@ -853,31 +827,10 @@ impl<K, V> Default for JoinMap<K, V> {
}
}

// === impl Key ===

impl<K: Hash> Hash for Key<K> {
// Don't include the task ID in the hash.
#[inline]
fn hash<H: Hasher>(&self, hasher: &mut H) {
self.key.hash(hasher);
}
}

// Because we override `Hash` for this type, we must also override the
// `PartialEq` impl, so that all instances with the same hash are equal.
impl<K: PartialEq> PartialEq for Key<K> {
#[inline]
fn eq(&self, other: &Self) -> bool {
self.key == other.key
}
}

impl<K: Eq> Eq for Key<K> {}

/// An iterator over the keys of a [`JoinMap`].
#[derive(Debug, Clone)]
pub struct JoinMapKeys<'a, K, V> {
iter: hashbrown::hash_map::Keys<'a, Key<K>, AbortHandle>,
iter: hashbrown::hash_table::Iter<'a, (K, AbortHandle)>,
/// To make it easier to change `JoinMap` in the future, keep V as a generic
/// parameter.
_value: PhantomData<&'a V>,
Expand All @@ -887,7 +840,7 @@ impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
type Item = &'a K;

fn next(&mut self) -> Option<&'a K> {
self.iter.next().map(|key| &key.key)
self.iter.next().map(|(key, _)| key)
}

fn size_hint(&self) -> (usize, Option<usize>) {