Skip to content

Commit 60222d0

Browse files
committed
Add a MutGlobal utility
Can be used for specifying dynamic components for tests, such as the signer.
1 parent adf2f68 commit 60222d0

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

lightning/src/util/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@ pub(crate) mod fuzz_wrappers;
1515
#[macro_use]
1616
pub mod ser_macros;
1717

18+
#[cfg(any(test, feature = "_test_utils"))]
19+
pub mod mut_global;
20+
1821
pub mod anchor_channel_reserves;
22+
1923
#[cfg(fuzzing)]
2024
pub mod base32;
2125
#[cfg(not(fuzzing))]

lightning/src/util/mut_global.rs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//! A settable global variable.
2+
//!
3+
//! Used for testing purposes only.
4+
5+
use std::sync::Mutex;
6+
7+
/// A global variable that can be set exactly once.
8+
pub struct MutGlobal<T> {
9+
value: Mutex<Option<T>>,
10+
default_fn: fn() -> T,
11+
}
12+
13+
impl<T: Clone> MutGlobal<T> {
14+
/// Create a new `MutGlobal` with no value set.
15+
///
16+
/// default_fn will be called to get the default value if the value is unset
17+
/// at the time the first call to `get` is made.
18+
pub const fn new(default_fn: fn() -> T) -> Self {
19+
Self { value: Mutex::new(None), default_fn }
20+
}
21+
22+
/// Set the value of the global variable.
23+
pub fn set(&self, value: T) {
24+
let mut lock = self.value.lock().unwrap();
25+
*lock = Some(value);
26+
}
27+
28+
/// Get the value of the global variable, or the default if unset.
29+
pub fn get(&self) -> T {
30+
let mut lock = self.value.lock().unwrap();
31+
if let Some(value) = &*lock {
32+
value.clone()
33+
} else {
34+
let value = (self.default_fn)();
35+
*lock = Some(value.clone());
36+
value
37+
}
38+
}
39+
}
40+
41+
#[cfg(test)]
42+
mod tests {
43+
use super::*;
44+
45+
#[test]
46+
fn test() {
47+
let v = MutGlobal::<u8>::new(|| 0);
48+
assert_eq!(v.get(), 0);
49+
v.set(42);
50+
assert_eq!(v.get(), 42);
51+
v.set(43);
52+
assert_eq!(v.get(), 43);
53+
}
54+
55+
static G: MutGlobal<u8> = MutGlobal::new(|| 0);
56+
57+
#[test]
58+
fn test_global() {
59+
G.set(42);
60+
assert_eq!(G.get(), 42);
61+
G.set(43);
62+
assert_eq!(G.get(), 43);
63+
}
64+
}

0 commit comments

Comments
 (0)