Skip to content

Commit 60a6f39

Browse files
authored
Context as type map (#504)
* Context as type map * Added header injector policy * Rylev's suggestions and clippy * removed iterator, pub internal HeaderMap
1 parent 8edf216 commit 60a6f39

File tree

6 files changed

+233
-5
lines changed

6 files changed

+233
-5
lines changed

sdk/core/src/context.rs

Lines changed: 167 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,175 @@
1+
use std::any::{Any, TypeId};
2+
use std::collections::HashMap;
3+
use std::sync::Arc;
4+
15
/// Pipeline execution context.
2-
#[derive(Clone, Debug, Default)]
6+
#[derive(Clone, Debug)]
37
pub struct Context {
4-
_priv: (),
8+
type_map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
9+
}
10+
11+
impl Default for Context {
12+
fn default() -> Self {
13+
Self::new()
14+
}
515
}
616

717
impl Context {
18+
/// Creates a new, empty Context.
819
pub fn new() -> Self {
9-
Self::default()
20+
Self {
21+
type_map: HashMap::new(),
22+
}
23+
}
24+
25+
/// Inserts or replaces an entity in the type map. If an entity with the same type was displaced
26+
/// by the insert, it will be returned to the caller.
27+
pub fn insert_or_replace<E>(&mut self, entity: E) -> Option<Arc<E>>
28+
where
29+
E: Send + Sync + 'static,
30+
{
31+
// we make sure that for every TypeId of E as key we ALWAYS retrieve an Option<Arc<E>>. That's why
32+
// the `unwrap` below is safe.
33+
self.type_map
34+
.insert(TypeId::of::<E>(), Arc::new(entity))
35+
.map(|displaced| displaced.downcast().unwrap())
36+
}
37+
38+
/// Inserts an entity in the type map. If the an entity with the same type signature is
39+
/// already present it will be silently dropped. This function returns a mutable reference to
40+
/// the same Context so it can be chained to itself.
41+
pub fn insert<E>(&mut self, entity: E) -> &mut Self
42+
where
43+
E: Send + Sync + 'static,
44+
{
45+
self.type_map.insert(TypeId::of::<E>(), Arc::new(entity));
46+
47+
self
48+
}
49+
50+
/// Removes an entity from the type map. If present, the entity will be returned.
51+
pub fn remove<E>(&mut self) -> Option<Arc<E>>
52+
where
53+
E: Send + Sync + 'static,
54+
{
55+
self.type_map
56+
.remove(&TypeId::of::<E>())
57+
.map(|removed| removed.downcast().unwrap())
58+
}
59+
60+
/// Returns a reference of the entity of the specified type signature, if it exists.
61+
///
62+
/// If there is no entity with the specific type signature, `None` is returned instead.
63+
pub fn get<E>(&self) -> Option<&E>
64+
where
65+
E: Send + Sync + 'static,
66+
{
67+
self.type_map
68+
.get(&TypeId::of::<E>())
69+
.map(|item| item.downcast_ref())
70+
.flatten()
71+
}
72+
73+
/// Returns the number of entities in the type map.
74+
pub fn len(&self) -> usize {
75+
self.type_map.len()
76+
}
77+
78+
/// Returns `true` if the type map is empty, `false` otherwise.
79+
pub fn is_empty(&self) -> bool {
80+
self.type_map.is_empty()
81+
}
82+
}
83+
84+
#[cfg(test)]
85+
mod tests {
86+
use super::*;
87+
use std::sync::Mutex;
88+
89+
#[test]
90+
fn insert_get_string() {
91+
let mut context = Context::new();
92+
context.insert_or_replace("pollo".to_string());
93+
assert_eq!(Some(&"pollo".to_string()), context.get());
94+
}
95+
96+
#[test]
97+
fn insert_get_custom_structs() {
98+
#[derive(Debug, PartialEq, Eq)]
99+
struct S1 {}
100+
#[derive(Debug, PartialEq, Eq)]
101+
struct S2 {}
102+
103+
let mut context = Context::new();
104+
context.insert_or_replace(S1 {});
105+
context.insert_or_replace(S2 {});
106+
107+
assert_eq!(Some(Arc::new(S1 {})), context.insert_or_replace(S1 {}));
108+
assert_eq!(Some(Arc::new(S2 {})), context.insert_or_replace(S2 {}));
109+
110+
assert_eq!(Some(&S1 {}), context.get());
111+
assert_eq!(Some(&S2 {}), context.get());
112+
}
113+
114+
#[test]
115+
fn insert_fluent_syntax() {
116+
#[derive(Debug, PartialEq, Eq, Default)]
117+
struct S1 {}
118+
#[derive(Debug, PartialEq, Eq, Default)]
119+
struct S2 {}
120+
121+
let mut context = Context::new();
122+
123+
context
124+
.insert("static str")
125+
.insert("a String".to_string())
126+
.insert(S1::default())
127+
.insert(S1::default()) // notice we are REPLACING S1. This call will *not* increment the counter
128+
.insert(S2::default());
129+
130+
assert_eq!(4, context.len());
131+
assert_eq!(Some(&"static str"), context.get());
132+
}
133+
134+
fn require_send_sync<T: Send + Sync>(_: &T) {}
135+
136+
#[test]
137+
fn test_require_send_sync() {
138+
// this won't compile if Context as a whole is not Send + Sync
139+
require_send_sync(&Context::new())
140+
}
141+
142+
#[test]
143+
fn mutability() {
144+
#[derive(Debug, PartialEq, Eq, Default)]
145+
struct S1 {
146+
num: u8,
147+
}
148+
let mut context = Context::new();
149+
context.insert_or_replace(Mutex::new(S1::default()));
150+
151+
// the stored value is 0.
152+
assert_eq!(0, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);
153+
154+
// we change the number to 42 in a thread safe manner.
155+
context.get::<Mutex<S1>>().unwrap().lock().unwrap().num = 42;
156+
157+
// now the number is 42.
158+
assert_eq!(42, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);
159+
160+
// we replace the struct with a new one.
161+
let displaced = context
162+
.insert_or_replace(Mutex::new(S1::default()))
163+
.unwrap();
164+
165+
// the displaced struct still holds 42 as number
166+
assert_eq!(42, displaced.lock().unwrap().num);
167+
168+
// the new struct has 0 has number.
169+
assert_eq!(0, context.get::<Mutex<S1>>().unwrap().lock().unwrap().num);
170+
171+
context.insert_or_replace(Mutex::new(33u32));
172+
*context.get::<Mutex<u32>>().unwrap().lock().unwrap() = 42;
173+
assert_eq!(42, *context.get::<Mutex<u32>>().unwrap().lock().unwrap());
10174
}
11175
}

sdk/core/src/pipeline.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#[cfg(not(target_arch = "wasm32"))]
22
use crate::policies::TransportPolicy;
3-
use crate::policies::{Policy, TelemetryPolicy};
3+
use crate::policies::{CustomHeadersInjectorPolicy, Policy, TelemetryPolicy};
44
use crate::{ClientOptions, Error, HttpClient, PipelineContext, Request, Response};
55
use std::sync::Arc;
66

@@ -70,6 +70,8 @@ where
7070
let telemetry_policy = TelemetryPolicy::new(crate_name, crate_version, &options.telemetry);
7171
pipeline.push(Arc::new(telemetry_policy));
7272

73+
pipeline.push(Arc::new(CustomHeadersInjectorPolicy::default()));
74+
7375
let retry_policy = options.retry.to_policy();
7476
pipeline.push(retry_policy);
7577

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
use crate::policies::{Policy, PolicyResult};
2+
use crate::{PipelineContext, Request, Response};
3+
use http::header::HeaderMap;
4+
use std::sync::Arc;
5+
6+
#[derive(Debug, Clone, PartialEq, Eq, Default)]
7+
pub struct CustomHeaders(pub HeaderMap);
8+
9+
impl From<HeaderMap> for CustomHeaders {
10+
fn from(header_map: HeaderMap) -> Self {
11+
Self(header_map)
12+
}
13+
}
14+
15+
#[derive(Clone, Debug, Default)]
16+
pub struct CustomHeadersInjectorPolicy {}
17+
18+
#[async_trait::async_trait]
19+
impl<C> Policy<C> for CustomHeadersInjectorPolicy
20+
where
21+
C: Send + Sync,
22+
{
23+
async fn send(
24+
&self,
25+
ctx: &mut PipelineContext<C>,
26+
request: &mut Request,
27+
next: &[Arc<dyn Policy<C>>],
28+
) -> PolicyResult<Response> {
29+
if let Some(CustomHeaders(custom_headers)) = ctx.get_inner_context().get::<CustomHeaders>()
30+
{
31+
custom_headers
32+
.iter()
33+
.for_each(|(header_name, header_value)| {
34+
log::trace!(
35+
"injecting custom context header {:?} with value {:?}",
36+
header_name,
37+
header_value
38+
);
39+
request
40+
.headers_mut()
41+
.insert(header_name, header_value.to_owned());
42+
});
43+
}
44+
45+
next[0].send(ctx, request, &next[1..]).await
46+
}
47+
}

sdk/core/src/policies/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod custom_headers_injector_policy;
12
#[cfg(feature = "mock_transport_framework")]
23
mod mock_transport_player_policy;
34
#[cfg(feature = "mock_transport_framework")]
@@ -7,6 +8,7 @@ mod telemetry_policy;
78
mod transport;
89

910
use crate::{PipelineContext, Request, Response};
11+
pub use custom_headers_injector_policy::{CustomHeaders, CustomHeadersInjectorPolicy};
1012
#[cfg(feature = "mock_transport_framework")]
1113
pub use mock_transport_player_policy::MockTransportPlayerPolicy;
1214
#[cfg(feature = "mock_transport_framework")]

sdk/core/src/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
pub use crate::etag::Etag;
2+
pub use crate::policies::CustomHeaders;
23
pub use crate::request_options::*;
34
pub use crate::{
45
new_http_client, AddAsHeader, AppendToUrlQuery, Context, HttpClient, RequestId, SessionToken,

sdk/cosmos/examples/get_database.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use azure_core::prelude::*;
22
use azure_cosmos::prelude::*;
3+
use http::{HeaderMap, HeaderValue};
34
use std::error::Error;
45

56
#[tokio::main]
@@ -24,8 +25,19 @@ async fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
2425

2526
let database_client = client.into_database_client(database_name.clone());
2627

28+
let mut context = Context::new();
29+
30+
// Next we create a CustomHeaders type and insert it into the context allowing us to insert custom headers.
31+
let custom_headers: CustomHeaders = {
32+
let mut custom_headers = HeaderMap::new();
33+
custom_headers.insert("MyCoolHeader", HeaderValue::from_static("CORS maybe?"));
34+
custom_headers.into()
35+
};
36+
37+
context.insert(custom_headers);
38+
2739
let response = database_client
28-
.get_database(Context::new(), GetDatabaseOptions::new())
40+
.get_database(context, GetDatabaseOptions::new())
2941
.await?;
3042
println!("response == {:?}", response);
3143

0 commit comments

Comments
 (0)