Skip to content

Implement custom token extractor #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions jwt-authorizer/src/authorizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::{
};

pub type ClaimsCheckerFn<C> = Arc<Box<dyn Fn(&C) -> bool + Send + Sync>>;
pub type TokenExtractorFn = Arc<Box<dyn Fn(&HeaderMap) -> Option<String> + Send + Sync>>;

pub struct Authorizer<C = RegisteredClaims>
where
Expand All @@ -23,6 +24,7 @@ where
pub claims_checker: Option<ClaimsCheckerFn<C>>,
pub validation: crate::validation::Validation,
pub jwt_source: JwtSource,
pub token_extractor: Option<TokenExtractorFn>,
}

fn read_data(path: &str) -> Result<Vec<u8>, InitError> {
Expand Down Expand Up @@ -56,6 +58,7 @@ where
refresh: Option<Refresh>,
validation: crate::validation::Validation,
jwt_source: JwtSource,
token_extractor: Option<TokenExtractorFn>,
http_client: Option<Client>,
) -> Result<Authorizer<C>, InitError> {
Ok(match key_source_type {
Expand All @@ -77,6 +80,7 @@ where
claims_checker,
validation,
jwt_source,
token_extractor,
}
}
KeySourceType::RSAString(text) => {
Expand All @@ -97,6 +101,7 @@ where
claims_checker,
validation,
jwt_source,
token_extractor,
}
}
KeySourceType::EC(path) => {
Expand All @@ -110,6 +115,7 @@ where
claims_checker,
validation,
jwt_source,
token_extractor,
}
}
KeySourceType::ECString(text) => {
Expand All @@ -123,6 +129,7 @@ where
claims_checker,
validation,
jwt_source,
token_extractor,
}
}
KeySourceType::ED(path) => {
Expand All @@ -136,6 +143,7 @@ where
claims_checker,
validation,
jwt_source,
token_extractor,
}
}
KeySourceType::EDString(text) => {
Expand All @@ -149,6 +157,7 @@ where
claims_checker,
validation,
jwt_source,
token_extractor,
}
}
KeySourceType::Secret(secret) => {
Expand All @@ -162,6 +171,7 @@ where
claims_checker,
validation,
jwt_source,
token_extractor,
}
}
KeySourceType::JwksPath(path) => {
Expand All @@ -179,6 +189,7 @@ where
claims_checker,
validation,
jwt_source,
token_extractor,
}
}
KeySourceType::JwksString(jwks_str) => {
Expand All @@ -197,6 +208,7 @@ where
claims_checker,
validation,
jwt_source,
token_extractor,
}
}
KeySourceType::Jwks(url) => {
Expand All @@ -207,6 +219,7 @@ where
claims_checker,
validation,
jwt_source,
token_extractor,
}
}
KeySourceType::Discovery(issuer_url) => {
Expand All @@ -219,6 +232,7 @@ where
claims_checker,
validation,
jwt_source,
token_extractor,
}
}
})
Expand All @@ -241,14 +255,18 @@ where
}

pub fn extract_token(&self, h: &HeaderMap) -> Option<String> {
match &self.jwt_source {
layer::JwtSource::AuthorizationHeader => {
let bearer_o: Option<Authorization<Bearer>> = h.typed_get();
bearer_o.map(|b| String::from(b.0.token()))
if let Some(ref extractor) = self.token_extractor {
extractor(h)
} else {
match &self.jwt_source {
layer::JwtSource::AuthorizationHeader => {
let bearer_o: Option<Authorization<Bearer>> = h.typed_get();
bearer_o.map(|b| String::from(b.0.token()))
}
layer::JwtSource::Cookie(name) => h
.typed_get::<headers::Cookie>()
.and_then(|c| c.get(name.as_str()).map(String::from)),
}
layer::JwtSource::Cookie(name) => h
.typed_get::<headers::Cookie>()
.and_then(|c| c.get(name.as_str()).map(String::from)),
}
}
}
Expand Down Expand Up @@ -334,6 +352,7 @@ mod tests {
Validation::new(),
JwtSource::AuthorizationHeader,
None,
None,
)
.await
.unwrap();
Expand All @@ -360,6 +379,7 @@ mod tests {
Validation::new(),
JwtSource::AuthorizationHeader,
None,
None,
)
.await
.unwrap();
Expand All @@ -376,6 +396,7 @@ mod tests {
Validation::new(),
JwtSource::AuthorizationHeader,
None,
None,
)
.await
.unwrap();
Expand All @@ -389,6 +410,7 @@ mod tests {
Validation::new(),
JwtSource::AuthorizationHeader,
None,
None,
)
.await
.unwrap();
Expand All @@ -402,6 +424,7 @@ mod tests {
Validation::new(),
JwtSource::AuthorizationHeader,
None,
None,
)
.await
.unwrap();
Expand All @@ -415,6 +438,7 @@ mod tests {
Validation::new(),
JwtSource::AuthorizationHeader,
None,
None,
)
.await
.unwrap();
Expand All @@ -441,6 +465,7 @@ mod tests {
Validation::new(),
JwtSource::AuthorizationHeader,
None,
None,
)
.await
.unwrap();
Expand All @@ -454,6 +479,7 @@ mod tests {
Validation::new(),
JwtSource::AuthorizationHeader,
None,
None,
)
.await
.unwrap();
Expand All @@ -467,6 +493,7 @@ mod tests {
Validation::new(),
JwtSource::AuthorizationHeader,
None,
None,
)
.await
.unwrap();
Expand All @@ -483,6 +510,7 @@ mod tests {
Validation::new(),
JwtSource::AuthorizationHeader,
None,
None,
)
.await;
println!("{:?}", a.as_ref().err());
Expand All @@ -498,6 +526,7 @@ mod tests {
Validation::default(),
JwtSource::AuthorizationHeader,
None,
None,
)
.await;
println!("{:?}", a.as_ref().err());
Expand All @@ -513,6 +542,7 @@ mod tests {
Validation::default(),
JwtSource::AuthorizationHeader,
None,
None,
)
.await;
println!("{:?}", a.as_ref().err());
Expand Down
25 changes: 24 additions & 1 deletion jwt-authorizer/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use serde::de::DeserializeOwned;

use crate::{
authorizer::{ClaimsCheckerFn, KeySourceType},
authorizer::{ClaimsCheckerFn, KeySourceType, TokenExtractorFn},
error::InitError,
layer::{AuthorizationLayer, JwtSource},
Authorizer, Refresh, RefreshStrategy, RegisteredClaims, Validation,
Expand All @@ -24,6 +24,7 @@ where
claims_checker: Option<ClaimsCheckerFn<C>>,
validation: Option<Validation>,
jwt_source: JwtSource,
token_extractor: Option<TokenExtractorFn>,
http_client: Option<Client>,
}

Expand All @@ -43,6 +44,7 @@ where
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
token_extractor: None,
http_client: None,
}
}
Expand All @@ -55,6 +57,7 @@ where
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
token_extractor: None,
http_client: None,
}
}
Expand All @@ -66,6 +69,7 @@ where
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
token_extractor: None,
http_client: None,
}
}
Expand All @@ -77,6 +81,7 @@ where
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
token_extractor: None,
http_client: None,
}
}
Expand All @@ -89,6 +94,7 @@ where
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
token_extractor: None,
http_client: None,
}
}
Expand All @@ -101,6 +107,7 @@ where
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
token_extractor: None,
http_client: None,
}
}
Expand All @@ -113,6 +120,7 @@ where
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
token_extractor: None,
http_client: None,
}
}
Expand All @@ -125,6 +133,7 @@ where
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
token_extractor: None,
http_client: None,
}
}
Expand All @@ -137,6 +146,7 @@ where
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
token_extractor: None,
http_client: None,
}
}
Expand All @@ -149,6 +159,7 @@ where
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
token_extractor: None,
http_client: None,
}
}
Expand All @@ -161,6 +172,7 @@ where
claims_checker: None,
validation: None,
jwt_source: JwtSource::AuthorizationHeader,
token_extractor: None,
http_client: None,
}
}
Expand Down Expand Up @@ -212,6 +224,15 @@ where
self
}

/// configures the token extractor function
///
/// (default: None)
pub fn token_extractor(mut self, token_extractor: TokenExtractorFn) -> AuthorizerBuilder<C> {
self.token_extractor = Some(token_extractor);

self
}

/// provide a custom http client for oicd requests
/// if not called, uses a default configured client
///
Expand All @@ -233,6 +254,7 @@ where
self.refresh,
val,
self.jwt_source,
self.token_extractor,
None,
)
.await?,
Expand All @@ -249,6 +271,7 @@ where
self.refresh,
val,
self.jwt_source,
self.token_extractor,
self.http_client,
)
.await
Expand Down
4 changes: 1 addition & 3 deletions jwt-authorizer/src/jwks/key_store_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,7 @@ impl KeyStoreManager {
)],
)
.await?;
ks_gard
.find_alg(&header.alg)
.ok_or_else(|| AuthError::InvalidKeyAlg(header.alg))?
ks_gard.find_alg(&header.alg).ok_or(AuthError::InvalidKeyAlg(header.alg))?
} else {
return Err(AuthError::InvalidKeyAlg(header.alg));
}
Expand Down
Loading
Loading