Commit a7b6d44

Simplify routing table design
Instead of doing dynamic table resizing, just allocate all the memory upfront.
1 parent 1f8cac2 commit a7b6d44
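As a rough, back-of-the-envelope sketch of what "allocate all the memory upfront" means here (the entry layout below is an illustrative stand-in, not the crate's actual NodeEntry/NodeInfo types): 160 k-buckets of 8 slots each is 1280 slots, allocated once and never resized.

// Back-of-the-envelope sketch of the upfront allocation; EntrySketch is a stand-in
// for the routing table entry, not the crate's real types.
#[allow(dead_code)]
#[derive(Clone, Copy)]
struct EntrySketch {
    id: [u8; 20],                       // 160-bit node id
    contact: std::net::SocketAddrV4,    // where to reach the node
    last_checked: std::time::Instant,   // freshness timestamp
}

fn main() {
    const BUCKETS: usize = 160; // one k-bucket per leading-bit distance
    const K: usize = 8;         // bucket capacity used by the commit

    let slots = BUCKETS * K;
    let bytes = slots * std::mem::size_of::<Option<EntrySketch>>();
    println!("{slots} slots, roughly {bytes} bytes allocated once, never resized");
}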


6 files changed: +110 -142 lines changed

src/dht_service.rs

Lines changed: 4 additions & 2 deletions
@@ -159,11 +159,13 @@ impl DhtV4 {
         let our_id = dht.our_id.clone();
 
         info!("bootstrapping with {peer}");
-        let response = timeout(Duration::from_secs(5), async {
+        let _response = timeout(Duration::from_secs(5), async {
             let node_id = dht.ping(peer).await?;
             dht.routing_table.add(node_id, peer);
 
-            dht.find_node(our_id).await;
+            // find_node on our own id will obviously find ourselves; this only serves to get us info
+            // about other nodes
+            let _ = dht.find_node(our_id).await;
             println!("done finding node");
 
             Ok::<(), eyre::Report>(())
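A minimal, self-contained sketch of the bootstrap pattern in the hunk above, assuming a tokio runtime; the helper name is made up, not an API of this crate. The point is to cap the step with a five-second timeout and deliberately discard both the timeout result and the inner result, since bootstrapping here only warms the routing table.

use std::time::Duration;
use tokio::time::timeout;

// Stand-in for the ping + find_node work done against a bootstrap peer.
async fn fake_bootstrap_step() -> Result<(), &'static str> {
    Ok(())
}

#[tokio::main]
async fn main() {
    // As in the hunk: the outcome is informational only, so both the timeout result
    // and the inner result are explicitly ignored with `_`.
    let _response = timeout(Duration::from_secs(5), async {
        let _ = fake_bootstrap_step().await;
        Ok::<(), &'static str>(())
    })
    .await;
}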

src/dht_service/dht_client.rs

Lines changed: 2 additions & 2 deletions
@@ -74,7 +74,7 @@ impl DhtHandle {
     pub async fn find_node(self: Arc<Self>, target: NodeId) -> Result<NodeInfo, OurError> {
         // if we already know the node, then no need for any network requests
         if let Some(node) = (&self).routing_table.find(target) {
-            return Ok(node.contact);
+            return Ok(node);
         }
 
         let mut queried: HashSet<NodeInfo> = HashSet::new();
@@ -180,7 +180,7 @@ impl DhtHandle {
         //
         // if we already know the node, then no need for any network requests
         if let Some(node) = (&self).routing_table.find(resonsible) {
-            let (token, _nodes, peers) = self.send_get_peers_rpc(node.contact.contact().0, info_hash).await?;
+            let (token, _nodes, peers) = self.send_get_peers_rpc(node.contact().0, info_hash).await?;
             return Ok((
                 token.expect("A node directly responsible for a piece would return a token"),
                 peers,

src/dht_service/dht_server.rs

Lines changed: 3 additions & 7 deletions
@@ -1,12 +1,8 @@
 use crate::{
     domain_knowledge::{InfoHash, NodeId, PeerContact, Token},
     message::{
-        announce_peer_query::AnnouncePeerQuery,
-        find_node_get_peers_response::{self, FindNodeGetPeersResponse},
-        find_node_query::{self, FindNodeQuery},
-        get_peers_query::GetPeersQuery,
-        ping_query::PingQuery,
-        Krpc,
+        announce_peer_query::AnnouncePeerQuery, find_node_query::FindNodeQuery, get_peers_query::GetPeersQuery,
+        ping_query::PingQuery, Krpc,
     },
 };
 use rand::RngCore;
@@ -24,7 +20,7 @@ use tokio::{
     task::Builder as TskBuilder,
     time::Instant,
 };
-use tracing::{error, info, info_span, trace, Instrument};
+use tracing::{info, info_span, trace, Instrument};
 
 use super::{peer_guide::PeerGuide, MessageBroker};
 #[derive(Debug)]

src/dht_service/peer_guide.rs

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@ use tokio::sync::mpsc;
 use crate::{
     domain_knowledge::{NodeId, NodeInfo},
     message::Krpc,
-    routing::{Node, RoutingTable},
+    routing::RoutingTable,
 };
 
 #[derive(Debug)]
@@ -57,7 +57,7 @@ impl PeerGuide {
         routing_table.find_closest(target)
     }
 
-    pub fn find(&self, target: NodeId) -> Option<Node> {
+    pub fn find(&self, target: NodeId) -> Option<NodeInfo> {
         let routing_table = self.routing_table.lock().unwrap();
         routing_table.find(target)
     }

src/domain_knowledge.rs

Lines changed: 28 additions & 1 deletion
@@ -3,7 +3,7 @@ use num::{traits::ops::bytes, BigUint};
 use smallvec::SmallVec;
 use std::{fmt::Debug, net::SocketAddrV4};
 
-#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+#[derive(PartialEq, Eq, Hash, Clone, Copy, PartialOrd, Ord)]
 pub struct NodeId(pub [u8; 20]);
 
 impl Debug for NodeId {
@@ -36,6 +36,14 @@ impl NodeId {
         let node_id = BigUint::from_bytes_be(rhs.as_bytes());
         our_id ^ node_id
     }
+
+    pub fn dist(&self, rhs: &Self) -> [u8; 20] {
+        let mut dist = [0u8; 20];
+        for i in 0..20 {
+            dist[i] = self.0[i] ^ rhs.0[i]
+        }
+        dist
+    }
 }
 
 impl ToBencode for NodeId {
@@ -46,6 +54,25 @@ impl ToBencode for NodeId {
     }
 }
 
+// impl PartialOrd for NodeId {
+//     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+//         for (lhs, rhs) in self.0.iter().zip(other.0.iter()) {
+//             match lhs.cmp(rhs) {
+//                 Ordering::Equal => continue,
+//                 Ordering::Less => return Some(Ordering::Less),
+//                 Ordering::Greater => return Some(Ordering::Greater),
+//             }
+//         }
+//         Some(Ordering::Equal)
+//     }
+// }
+//
+// impl Ord for NodeId {
+//     fn cmp(&self, other: &Self) -> Ordering {
+//         self.partial_cmp(other).unwrap()
+//     }
+// }
+
 #[derive(PartialEq, Eq, Hash, Clone, Copy)]
 pub struct InfoHash(pub [u8; 20]);
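The new NodeId::dist above is a plain byte-wise XOR. A standalone sketch of its basic properties, with NodeId re-declared locally for illustration rather than imported from the crate:

// Standalone illustration of byte-wise XOR distance; this NodeId is a local copy,
// not the crate's own type.
#[derive(PartialEq, Eq, Clone, Copy)]
struct NodeId([u8; 20]);

impl NodeId {
    fn dist(&self, rhs: &Self) -> [u8; 20] {
        let mut dist = [0u8; 20];
        for i in 0..20 {
            dist[i] = self.0[i] ^ rhs.0[i];
        }
        dist
    }
}

fn main() {
    let a = NodeId([0xAB; 20]);
    let b = NodeId([0xCD; 20]);

    // XOR distance is zero against yourself and symmetric between peers.
    assert_eq!(a.dist(&a), [0u8; 20]);
    assert_eq!(a.dist(&b), b.dist(&a));
    println!("first distance byte: {:#04x}", a.dist(&b)[0]); // 0xAB ^ 0xCD = 0x66
}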

src/routing.rs

Lines changed: 71 additions & 128 deletions
@@ -1,178 +1,121 @@
 use crate::domain_knowledge::{NodeId, NodeInfo};
-use num::BigUint;
-use std::{ops::BitXor, str::FromStr, time::Instant};
-use tracing::{info, trace};
+use std::time::Instant;
+use tracing::info;
 
 /// The routing table at the heart of the Kademlia DHT. It keeps the near neighbors of ourself.
-#[derive(Debug)]
+#[derive(Debug, Hash, PartialEq, Eq, Clone)]
 pub struct RoutingTable {
+    bucket_size: usize,
     /// The node id of ourself.
-    id: BigUint,
+    id: NodeId,
 
-    /// each bucket contains
-    pub(crate) buckets: Vec<Bucket>,
-}
-
-#[derive(Debug)]
-pub struct Bucket {
-    /// inclusive
-    lower_bound: BigUint,
-    /// exclusive
-    upper_bound: BigUint,
-
-    // TODO: technically a bucket is at most 8 nodes, use a fixed size vector
-    nodes: Vec<Node>,
-}
-
-impl Bucket {
-    pub fn full(&self) -> bool {
-        assert!(self.nodes.len() <= 8);
-        self.nodes.len() >= 8
-    }
+    pub(crate) buckets: Box<[Option<NodeEntry>]>,
 }
 
 #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)]
-// TODO: this name is shit, think of a better one
-pub struct Node {
+pub struct NodeEntry {
     pub(crate) contact: NodeInfo,
     pub(crate) last_checked: Instant,
 }
 
 impl RoutingTable {
     pub fn new(id: NodeId) -> Self {
-        let default_bucket = Bucket {
-            lower_bound: BigUint::from(0u8),
-            // 2^160
-            upper_bound: BigUint::from_str("1461501637330902918203684832716283019655932542976").unwrap(),
-            nodes: Vec::new(),
-        };
-
         RoutingTable {
-            id: BigUint::from_bytes_be(id.as_bytes()),
-            buckets: vec![default_bucket],
+            bucket_size: 8,
+            id,
+            buckets: Box::new([Option::None; 160 * 8]),
         }
     }
 
     pub fn node_count(&self) -> usize {
-        self.buckets.iter().map(|b| b.nodes.len()).sum()
+        // TODO: optimize this
+        self.buckets.iter().filter(|n| n.is_some()).count()
    }
 
     /// Add a new node to the routing table; if the buckets are full, the node will be ignored.
     pub fn add_new_node(&mut self, contact: NodeInfo) {
-        // TODO: handle duplicate nodes
-
         // there is a special case, when we already know this node, in that case, we just update the
         // last_checked timestamp.
-        if let Some(node) = self
+        let exact_match = self
             .buckets
             .iter_mut()
-            .map(|b| b.nodes.iter_mut())
             .flatten()
-            .find(|node| node.contact.id() == contact.id())
-        {
-            node.last_checked = Instant::now();
+            .find(|node| node.contact.id() == contact.id());
+        if let Some(n) = exact_match {
+            n.last_checked = Instant::now();
             return;
         }
 
-        let our_id = &self.id;
-        let distance = our_id.bitxor(BigUint::from_bytes_be(contact.id().as_bytes()));
+        let bucket = self.bucket_for_mut(&contact.id());
+        let slot = bucket.iter_mut().find(|n| n.is_none());
 
-        // first, find the bucket that this node belongs in
-        let target_bucket = self
-            .buckets
-            .iter_mut()
-            .find(|bucket| bucket.lower_bound <= distance && distance < bucket.upper_bound)
-            .unwrap();
-
-        let (full, within_our_bucket) = (
-            target_bucket.full(),
-            &target_bucket.lower_bound <= our_id && our_id < &target_bucket.upper_bound,
-        );
-        match (full, within_our_bucket) {
-            // if the bucket is full and our id is within our bucket, we need to split it
-            (true, true) => {
-                // split the bucket, the new bucket is the upper half of the old bucket
-                let mut new_bucket = Bucket {
-                    lower_bound: &target_bucket.upper_bound / 2u8,
-                    upper_bound: target_bucket.upper_bound.clone(),
-                    nodes: Vec::new(),
-                };
-
-                // transfer all the nodes that should go into the new bucket into the right place
-                // do I prefer the draining_filter API? yes but that's sadly nightly only
-                let mut i = 0;
-                while i < target_bucket.nodes.len() {
-                    let target_bucket_node_id = BigUint::from_bytes_be(target_bucket.nodes[i].contact.id().as_bytes());
-                    if &target_bucket_node_id <= &new_bucket.lower_bound {
-                        let node = target_bucket.nodes.remove(i);
-                        new_bucket.nodes.push(node);
-                    } else {
-                        i += 1;
-                    }
-                }
-
-                target_bucket.upper_bound = &target_bucket.upper_bound / 2u8;
-                self.buckets.push(new_bucket);
-                trace!("bucket split");
-            }
-            // if the bucket id range is not within our id and the bucket is full, we don't need to do
-            // anything
-            (true, false) => {
-                trace!("node not added, bucket full and not within our id");
-            }
-            // if the buckets are not full, then happy days, we just add the new node
-            (false, _) => {
-                target_bucket.nodes.push(Node {
+        // TODO: I recall there is more sophisticated logic on whether to ignore the insertion or not
+        match slot {
+            Some(inner) => {
+                inner.replace(NodeEntry {
                     contact,
                     last_checked: Instant::now(),
                 });
-                trace!("node added");
+                info!("{contact:?} added to routing table");
+                return ();
+            }
+            None => {
+                info!("table full, {contact:?} not added");
+                return (); // we're full
             }
         }
-        info!("node processed, node count: {}", self.node_count());
     }
 
+    // TODO: return an iterator instead?
     pub fn find_closest(&self, target: NodeId) -> Vec<NodeInfo> {
-        let mut closest_nodes: Vec<_> = self
-            .buckets
-            .iter()
-            .map(|bucket| {
-                bucket.nodes.iter().map(|node| {
-                    let node_id = node.contact.id();
-                    let node_id = node_id.as_bytes();
-                    let target = target.as_bytes();
-
-                    let mut distance = [0u8; 20];
-
-                    // zip for array is sadly unstable
-                    let mut i = 0;
-                    while i < 20 {
-                        distance[i] = node_id[i] ^ target[i];
-                        i += 1;
-                    }
-
-                    (BigUint::from_bytes_be(&distance), &node.contact)
-                })
-            })
-            .flatten()
-            .collect();
+        let (bucket_i, bucket_i_end) = self.indices(&target);
+        let bucket = &self.buckets[bucket_i..bucket_i_end];
 
-        closest_nodes.sort_unstable_by_key(|x| x.0.clone());
-        closest_nodes
-            .iter()
-            .filter(|(_, node)| node.id() != target)
-            .take(8)
-            .map(|x| x.1)
-            .cloned()
-            .collect()
+        let mut valid_entries: Vec<_> = bucket.iter().flatten().collect();
+        valid_entries.sort_unstable_by_key(|e| e.contact.id());
+
+        valid_entries.into_iter().map(|n| n.contact).collect()
     }
 
-    pub fn find(&self, target: NodeId) -> Option<Node> {
+    pub fn find(&self, target: NodeId) -> Option<NodeInfo> {
         self.buckets
             .iter()
-            .map(|bucket| bucket.nodes.iter())
             .flatten()
             .find(|node| node.contact.id() == target)
-            .cloned()
+            .map(|n| n.contact)
+            .clone()
+    }
+
+    /// Returns the `index` where `self.buckets[index]` is the first entry in the corresponding
+    /// k-bucket, and `index + (self.bucket_size - 1)` is the last entry (inclusive), so
+    /// `index..(index + self.bucket_size)` is the valid range.
+    fn index(&self, target: &NodeId) -> usize {
+        // Each k-bucket at index i stores nodes of distance [2^i, 2^(i + 1)) from ourself. The
+        // first byte in the distance that is not zero tells us that the distance falls
+        // within [2^i, 2^(i + 1)).
+
+        let dist = self.id.dist(&target);
+        // an all-zero distance means we're finding ourselves, so we go look in the 0th bucket.
+        let first_nonzero = dist.iter().position(|radix| *radix != 0).unwrap_or(159);
+        (159 - first_nonzero) * self.bucket_size
+    }
+
+    /// [begin, end) range of the corresponding k-bucket for `target`; the range should be accessed
+    /// directly like `self.buckets[begin]`, the stride jumping is already done for you.
+    fn indices(&self, target: &NodeId) -> (usize, usize) {
+        let begin = self.index(target);
+        let end = begin + self.bucket_size;
+        (begin, end)
+    }
+
+    #[allow(unused)]
+    fn bucket_for(&self, target: &NodeId) -> &[Option<NodeEntry>] {
+        let (begin, end) = self.indices(target);
+        &self.buckets[begin..end]
+    }
+
+    fn bucket_for_mut(&mut self, target: &NodeId) -> &mut [Option<NodeEntry>] {
+        let (begin, end) = self.indices(target);
+        &mut self.buckets[begin..end]
     }
 }
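To make the flat layout concrete, here is a standalone sketch of the slot arithmetic that index/indices describe: one slab of 160 × bucket_size slots, where a bucket is the contiguous range starting at (159 - position of the first nonzero distance byte) * bucket_size. The names and the Option<u32> stand-in are illustrative, not the crate's code.

// Standalone sketch of the flat k-bucket addressing used above: 160 buckets of
// BUCKET_SIZE slots in one slab, with the bucket chosen from the XOR distance.
const BUCKET_SIZE: usize = 8;
const NUM_BUCKETS: usize = 160;

// Map a 20-byte XOR distance to the start index of its bucket's slot range.
fn bucket_start(dist: &[u8; 20]) -> usize {
    // an all-zero distance (ourselves) falls back to the 0th bucket
    let first_nonzero = dist.iter().position(|b| *b != 0).unwrap_or(159);
    (159 - first_nonzero) * BUCKET_SIZE
}

fn main() {
    // one flat slab, allocated once; Option<u32> stands in for Option<NodeEntry>
    let slab = vec![None::<u32>; NUM_BUCKETS * BUCKET_SIZE];

    // a distance whose first nonzero byte is at position 0 (a "far away" node)
    let mut far = [0u8; 20];
    far[0] = 0x80;
    let begin = bucket_start(&far);
    let bucket = &slab[begin..begin + BUCKET_SIZE];

    println!("far bucket is slots {begin}..{}", begin + BUCKET_SIZE);
    assert_eq!(bucket.len(), BUCKET_SIZE);

    // the all-zero distance (looking up ourselves) lands in bucket 0
    assert_eq!(bucket_start(&[0u8; 20]), 0);
}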
