Skip to content

Commit 04258ca

Browse files
committed
Implement custom token extractor
1 parent 10a926c commit 04258ca

File tree

3 files changed

+123
-9
lines changed

3 files changed

+123
-9
lines changed

jwt-authorizer/src/authorizer.rs

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::{
1414
};
1515

1616
pub type ClaimsCheckerFn<C> = Arc<Box<dyn Fn(&C) -> bool + Send + Sync>>;
17+
pub type TokenExtractorFn = Arc<Box<dyn Fn(&HeaderMap) -> Option<String> + Send + Sync>>;
1718

1819
pub struct Authorizer<C = RegisteredClaims>
1920
where
@@ -23,6 +24,7 @@ where
2324
pub claims_checker: Option<ClaimsCheckerFn<C>>,
2425
pub validation: crate::validation::Validation,
2526
pub jwt_source: JwtSource,
27+
pub token_extractor: Option<TokenExtractorFn>,
2628
}
2729

2830
fn read_data(path: &str) -> Result<Vec<u8>, InitError> {
@@ -56,6 +58,7 @@ where
5658
refresh: Option<Refresh>,
5759
validation: crate::validation::Validation,
5860
jwt_source: JwtSource,
61+
token_extractor: Option<TokenExtractorFn>,
5962
http_client: Option<Client>,
6063
) -> Result<Authorizer<C>, InitError> {
6164
Ok(match key_source_type {
@@ -77,6 +80,7 @@ where
7780
claims_checker,
7881
validation,
7982
jwt_source,
83+
token_extractor,
8084
}
8185
}
8286
KeySourceType::RSAString(text) => {
@@ -97,6 +101,7 @@ where
97101
claims_checker,
98102
validation,
99103
jwt_source,
104+
token_extractor,
100105
}
101106
}
102107
KeySourceType::EC(path) => {
@@ -110,6 +115,7 @@ where
110115
claims_checker,
111116
validation,
112117
jwt_source,
118+
token_extractor,
113119
}
114120
}
115121
KeySourceType::ECString(text) => {
@@ -123,6 +129,7 @@ where
123129
claims_checker,
124130
validation,
125131
jwt_source,
132+
token_extractor,
126133
}
127134
}
128135
KeySourceType::ED(path) => {
@@ -136,6 +143,7 @@ where
136143
claims_checker,
137144
validation,
138145
jwt_source,
146+
token_extractor,
139147
}
140148
}
141149
KeySourceType::EDString(text) => {
@@ -149,6 +157,7 @@ where
149157
claims_checker,
150158
validation,
151159
jwt_source,
160+
token_extractor,
152161
}
153162
}
154163
KeySourceType::Secret(secret) => {
@@ -162,6 +171,7 @@ where
162171
claims_checker,
163172
validation,
164173
jwt_source,
174+
token_extractor,
165175
}
166176
}
167177
KeySourceType::JwksPath(path) => {
@@ -179,6 +189,7 @@ where
179189
claims_checker,
180190
validation,
181191
jwt_source,
192+
token_extractor,
182193
}
183194
}
184195
KeySourceType::JwksString(jwks_str) => {
@@ -197,6 +208,7 @@ where
197208
claims_checker,
198209
validation,
199210
jwt_source,
211+
token_extractor,
200212
}
201213
}
202214
KeySourceType::Jwks(url) => {
@@ -207,6 +219,7 @@ where
207219
claims_checker,
208220
validation,
209221
jwt_source,
222+
token_extractor,
210223
}
211224
}
212225
KeySourceType::Discovery(issuer_url) => {
@@ -219,6 +232,7 @@ where
219232
claims_checker,
220233
validation,
221234
jwt_source,
235+
token_extractor,
222236
}
223237
}
224238
})
@@ -241,14 +255,18 @@ where
241255
}
242256

243257
pub fn extract_token(&self, h: &HeaderMap) -> Option<String> {
244-
match &self.jwt_source {
245-
layer::JwtSource::AuthorizationHeader => {
246-
let bearer_o: Option<Authorization<Bearer>> = h.typed_get();
247-
bearer_o.map(|b| String::from(b.0.token()))
258+
if let Some(ref extractor) = self.token_extractor {
259+
extractor(h)
260+
} else {
261+
match &self.jwt_source {
262+
layer::JwtSource::AuthorizationHeader => {
263+
let bearer_o: Option<Authorization<Bearer>> = h.typed_get();
264+
bearer_o.map(|b| String::from(b.0.token()))
265+
}
266+
layer::JwtSource::Cookie(name) => h
267+
.typed_get::<headers::Cookie>()
268+
.and_then(|c| c.get(name.as_str()).map(String::from)),
248269
}
249-
layer::JwtSource::Cookie(name) => h
250-
.typed_get::<headers::Cookie>()
251-
.and_then(|c| c.get(name.as_str()).map(String::from)),
252270
}
253271
}
254272
}
@@ -334,6 +352,7 @@ mod tests {
334352
Validation::new(),
335353
JwtSource::AuthorizationHeader,
336354
None,
355+
None,
337356
)
338357
.await
339358
.unwrap();
@@ -360,6 +379,7 @@ mod tests {
360379
Validation::new(),
361380
JwtSource::AuthorizationHeader,
362381
None,
382+
None,
363383
)
364384
.await
365385
.unwrap();
@@ -376,6 +396,7 @@ mod tests {
376396
Validation::new(),
377397
JwtSource::AuthorizationHeader,
378398
None,
399+
None,
379400
)
380401
.await
381402
.unwrap();
@@ -389,6 +410,7 @@ mod tests {
389410
Validation::new(),
390411
JwtSource::AuthorizationHeader,
391412
None,
413+
None,
392414
)
393415
.await
394416
.unwrap();
@@ -402,6 +424,7 @@ mod tests {
402424
Validation::new(),
403425
JwtSource::AuthorizationHeader,
404426
None,
427+
None,
405428
)
406429
.await
407430
.unwrap();
@@ -415,6 +438,7 @@ mod tests {
415438
Validation::new(),
416439
JwtSource::AuthorizationHeader,
417440
None,
441+
None,
418442
)
419443
.await
420444
.unwrap();
@@ -441,6 +465,7 @@ mod tests {
441465
Validation::new(),
442466
JwtSource::AuthorizationHeader,
443467
None,
468+
None,
444469
)
445470
.await
446471
.unwrap();
@@ -454,6 +479,7 @@ mod tests {
454479
Validation::new(),
455480
JwtSource::AuthorizationHeader,
456481
None,
482+
None,
457483
)
458484
.await
459485
.unwrap();
@@ -467,6 +493,7 @@ mod tests {
467493
Validation::new(),
468494
JwtSource::AuthorizationHeader,
469495
None,
496+
None,
470497
)
471498
.await
472499
.unwrap();
@@ -483,6 +510,7 @@ mod tests {
483510
Validation::new(),
484511
JwtSource::AuthorizationHeader,
485512
None,
513+
None,
486514
)
487515
.await;
488516
println!("{:?}", a.as_ref().err());
@@ -498,6 +526,7 @@ mod tests {
498526
Validation::default(),
499527
JwtSource::AuthorizationHeader,
500528
None,
529+
None,
501530
)
502531
.await;
503532
println!("{:?}", a.as_ref().err());
@@ -513,6 +542,7 @@ mod tests {
513542
Validation::default(),
514543
JwtSource::AuthorizationHeader,
515544
None,
545+
None,
516546
)
517547
.await;
518548
println!("{:?}", a.as_ref().err());

jwt-authorizer/src/builder.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::sync::Arc;
33
use serde::de::DeserializeOwned;
44

55
use crate::{
6-
authorizer::{ClaimsCheckerFn, KeySourceType},
6+
authorizer::{ClaimsCheckerFn, KeySourceType, TokenExtractorFn},
77
error::InitError,
88
layer::{AuthorizationLayer, JwtSource},
99
Authorizer, Refresh, RefreshStrategy, RegisteredClaims, Validation,
@@ -24,6 +24,7 @@ where
2424
claims_checker: Option<ClaimsCheckerFn<C>>,
2525
validation: Option<Validation>,
2626
jwt_source: JwtSource,
27+
token_extractor: Option<TokenExtractorFn>,
2728
http_client: Option<Client>,
2829
}
2930

@@ -43,6 +44,7 @@ where
4344
claims_checker: None,
4445
validation: None,
4546
jwt_source: JwtSource::AuthorizationHeader,
47+
token_extractor: None,
4648
http_client: None,
4749
}
4850
}
@@ -55,6 +57,7 @@ where
5557
claims_checker: None,
5658
validation: None,
5759
jwt_source: JwtSource::AuthorizationHeader,
60+
token_extractor: None,
5861
http_client: None,
5962
}
6063
}
@@ -66,6 +69,7 @@ where
6669
claims_checker: None,
6770
validation: None,
6871
jwt_source: JwtSource::AuthorizationHeader,
72+
token_extractor: None,
6973
http_client: None,
7074
}
7175
}
@@ -77,6 +81,7 @@ where
7781
claims_checker: None,
7882
validation: None,
7983
jwt_source: JwtSource::AuthorizationHeader,
84+
token_extractor: None,
8085
http_client: None,
8186
}
8287
}
@@ -89,6 +94,7 @@ where
8994
claims_checker: None,
9095
validation: None,
9196
jwt_source: JwtSource::AuthorizationHeader,
97+
token_extractor: None,
9298
http_client: None,
9399
}
94100
}
@@ -101,6 +107,7 @@ where
101107
claims_checker: None,
102108
validation: None,
103109
jwt_source: JwtSource::AuthorizationHeader,
110+
token_extractor: None,
104111
http_client: None,
105112
}
106113
}
@@ -113,6 +120,7 @@ where
113120
claims_checker: None,
114121
validation: None,
115122
jwt_source: JwtSource::AuthorizationHeader,
123+
token_extractor: None,
116124
http_client: None,
117125
}
118126
}
@@ -125,6 +133,7 @@ where
125133
claims_checker: None,
126134
validation: None,
127135
jwt_source: JwtSource::AuthorizationHeader,
136+
token_extractor: None,
128137
http_client: None,
129138
}
130139
}
@@ -137,6 +146,7 @@ where
137146
claims_checker: None,
138147
validation: None,
139148
jwt_source: JwtSource::AuthorizationHeader,
149+
token_extractor: None,
140150
http_client: None,
141151
}
142152
}
@@ -149,6 +159,7 @@ where
149159
claims_checker: None,
150160
validation: None,
151161
jwt_source: JwtSource::AuthorizationHeader,
162+
token_extractor: None,
152163
http_client: None,
153164
}
154165
}
@@ -161,6 +172,7 @@ where
161172
claims_checker: None,
162173
validation: None,
163174
jwt_source: JwtSource::AuthorizationHeader,
175+
token_extractor: None,
164176
http_client: None,
165177
}
166178
}
@@ -212,6 +224,15 @@ where
212224
self
213225
}
214226

227+
/// configures the token extractor function
228+
///
229+
/// (default: None)
230+
pub fn token_extractor(mut self, token_extractor: TokenExtractorFn) -> AuthorizerBuilder<C> {
231+
self.token_extractor = Some(token_extractor);
232+
233+
self
234+
}
235+
215236
/// provide a custom http client for oicd requests
216237
/// if not called, uses a default configured client
217238
///
@@ -233,6 +254,7 @@ where
233254
self.refresh,
234255
val,
235256
self.jwt_source,
257+
self.token_extractor,
236258
None,
237259
)
238260
.await?,
@@ -249,6 +271,7 @@ where
249271
self.refresh,
250272
val,
251273
self.jwt_source,
274+
self.token_extractor,
252275
self.http_client,
253276
)
254277
.await

0 commit comments

Comments
 (0)