@@ -14,6 +14,7 @@ use crate::{
14
14
} ;
15
15
16
16
pub type ClaimsCheckerFn < C > = Arc < Box < dyn Fn ( & C ) -> bool + Send + Sync > > ;
17
+ pub type TokenExtractorFn = Arc < Box < dyn Fn ( & HeaderMap ) -> Option < String > + Send + Sync > > ;
17
18
18
19
pub struct Authorizer < C = RegisteredClaims >
19
20
where
23
24
pub claims_checker : Option < ClaimsCheckerFn < C > > ,
24
25
pub validation : crate :: validation:: Validation ,
25
26
pub jwt_source : JwtSource ,
27
+ pub token_extractor : Option < TokenExtractorFn > ,
26
28
}
27
29
28
30
fn read_data ( path : & str ) -> Result < Vec < u8 > , InitError > {
56
58
refresh : Option < Refresh > ,
57
59
validation : crate :: validation:: Validation ,
58
60
jwt_source : JwtSource ,
61
+ token_extractor : Option < TokenExtractorFn > ,
59
62
http_client : Option < Client > ,
60
63
) -> Result < Authorizer < C > , InitError > {
61
64
Ok ( match key_source_type {
77
80
claims_checker,
78
81
validation,
79
82
jwt_source,
83
+ token_extractor,
80
84
}
81
85
}
82
86
KeySourceType :: RSAString ( text) => {
97
101
claims_checker,
98
102
validation,
99
103
jwt_source,
104
+ token_extractor,
100
105
}
101
106
}
102
107
KeySourceType :: EC ( path) => {
@@ -110,6 +115,7 @@ where
110
115
claims_checker,
111
116
validation,
112
117
jwt_source,
118
+ token_extractor,
113
119
}
114
120
}
115
121
KeySourceType :: ECString ( text) => {
@@ -123,6 +129,7 @@ where
123
129
claims_checker,
124
130
validation,
125
131
jwt_source,
132
+ token_extractor,
126
133
}
127
134
}
128
135
KeySourceType :: ED ( path) => {
@@ -136,6 +143,7 @@ where
136
143
claims_checker,
137
144
validation,
138
145
jwt_source,
146
+ token_extractor,
139
147
}
140
148
}
141
149
KeySourceType :: EDString ( text) => {
@@ -149,6 +157,7 @@ where
149
157
claims_checker,
150
158
validation,
151
159
jwt_source,
160
+ token_extractor,
152
161
}
153
162
}
154
163
KeySourceType :: Secret ( secret) => {
@@ -162,6 +171,7 @@ where
162
171
claims_checker,
163
172
validation,
164
173
jwt_source,
174
+ token_extractor,
165
175
}
166
176
}
167
177
KeySourceType :: JwksPath ( path) => {
@@ -179,6 +189,7 @@ where
179
189
claims_checker,
180
190
validation,
181
191
jwt_source,
192
+ token_extractor,
182
193
}
183
194
}
184
195
KeySourceType :: JwksString ( jwks_str) => {
@@ -197,6 +208,7 @@ where
197
208
claims_checker,
198
209
validation,
199
210
jwt_source,
211
+ token_extractor,
200
212
}
201
213
}
202
214
KeySourceType :: Jwks ( url) => {
@@ -207,6 +219,7 @@ where
207
219
claims_checker,
208
220
validation,
209
221
jwt_source,
222
+ token_extractor,
210
223
}
211
224
}
212
225
KeySourceType :: Discovery ( issuer_url) => {
@@ -219,6 +232,7 @@ where
219
232
claims_checker,
220
233
validation,
221
234
jwt_source,
235
+ token_extractor,
222
236
}
223
237
}
224
238
} )
@@ -241,14 +255,18 @@ where
241
255
}
242
256
243
257
pub fn extract_token ( & self , h : & HeaderMap ) -> Option < String > {
244
- match & self . jwt_source {
245
- layer:: JwtSource :: AuthorizationHeader => {
246
- let bearer_o: Option < Authorization < Bearer > > = h. typed_get ( ) ;
247
- bearer_o. map ( |b| String :: from ( b. 0 . token ( ) ) )
258
+ if let Some ( ref extractor) = self . token_extractor {
259
+ extractor ( h)
260
+ } else {
261
+ match & self . jwt_source {
262
+ layer:: JwtSource :: AuthorizationHeader => {
263
+ let bearer_o: Option < Authorization < Bearer > > = h. typed_get ( ) ;
264
+ bearer_o. map ( |b| String :: from ( b. 0 . token ( ) ) )
265
+ }
266
+ layer:: JwtSource :: Cookie ( name) => h
267
+ . typed_get :: < headers:: Cookie > ( )
268
+ . and_then ( |c| c. get ( name. as_str ( ) ) . map ( String :: from) ) ,
248
269
}
249
- layer:: JwtSource :: Cookie ( name) => h
250
- . typed_get :: < headers:: Cookie > ( )
251
- . and_then ( |c| c. get ( name. as_str ( ) ) . map ( String :: from) ) ,
252
270
}
253
271
}
254
272
}
@@ -334,6 +352,7 @@ mod tests {
334
352
Validation :: new ( ) ,
335
353
JwtSource :: AuthorizationHeader ,
336
354
None ,
355
+ None ,
337
356
)
338
357
. await
339
358
. unwrap ( ) ;
@@ -360,6 +379,7 @@ mod tests {
360
379
Validation :: new ( ) ,
361
380
JwtSource :: AuthorizationHeader ,
362
381
None ,
382
+ None ,
363
383
)
364
384
. await
365
385
. unwrap ( ) ;
@@ -376,6 +396,7 @@ mod tests {
376
396
Validation :: new ( ) ,
377
397
JwtSource :: AuthorizationHeader ,
378
398
None ,
399
+ None ,
379
400
)
380
401
. await
381
402
. unwrap ( ) ;
@@ -389,6 +410,7 @@ mod tests {
389
410
Validation :: new ( ) ,
390
411
JwtSource :: AuthorizationHeader ,
391
412
None ,
413
+ None ,
392
414
)
393
415
. await
394
416
. unwrap ( ) ;
@@ -402,6 +424,7 @@ mod tests {
402
424
Validation :: new ( ) ,
403
425
JwtSource :: AuthorizationHeader ,
404
426
None ,
427
+ None ,
405
428
)
406
429
. await
407
430
. unwrap ( ) ;
@@ -415,6 +438,7 @@ mod tests {
415
438
Validation :: new ( ) ,
416
439
JwtSource :: AuthorizationHeader ,
417
440
None ,
441
+ None ,
418
442
)
419
443
. await
420
444
. unwrap ( ) ;
@@ -441,6 +465,7 @@ mod tests {
441
465
Validation :: new ( ) ,
442
466
JwtSource :: AuthorizationHeader ,
443
467
None ,
468
+ None ,
444
469
)
445
470
. await
446
471
. unwrap ( ) ;
@@ -454,6 +479,7 @@ mod tests {
454
479
Validation :: new ( ) ,
455
480
JwtSource :: AuthorizationHeader ,
456
481
None ,
482
+ None ,
457
483
)
458
484
. await
459
485
. unwrap ( ) ;
@@ -467,6 +493,7 @@ mod tests {
467
493
Validation :: new ( ) ,
468
494
JwtSource :: AuthorizationHeader ,
469
495
None ,
496
+ None ,
470
497
)
471
498
. await
472
499
. unwrap ( ) ;
@@ -483,6 +510,7 @@ mod tests {
483
510
Validation :: new ( ) ,
484
511
JwtSource :: AuthorizationHeader ,
485
512
None ,
513
+ None ,
486
514
)
487
515
. await ;
488
516
println ! ( "{:?}" , a. as_ref( ) . err( ) ) ;
@@ -498,6 +526,7 @@ mod tests {
498
526
Validation :: default ( ) ,
499
527
JwtSource :: AuthorizationHeader ,
500
528
None ,
529
+ None ,
501
530
)
502
531
. await ;
503
532
println ! ( "{:?}" , a. as_ref( ) . err( ) ) ;
@@ -513,6 +542,7 @@ mod tests {
513
542
Validation :: default ( ) ,
514
543
JwtSource :: AuthorizationHeader ,
515
544
None ,
545
+ None ,
516
546
)
517
547
. await ;
518
548
println ! ( "{:?}" , a. as_ref( ) . err( ) ) ;
0 commit comments