From 04258ca89714458ff8b2baa7c5c59f1990bc2734 Mon Sep 17 00:00:00 2001 From: Vladislav Manchev Date: Thu, 25 Apr 2024 20:21:35 +0300 Subject: [PATCH 1/4] Implement custom token extractor --- jwt-authorizer/src/authorizer.rs | 44 ++++++++++++++++++---- jwt-authorizer/src/builder.rs | 25 ++++++++++++- jwt-authorizer/tests/tests.rs | 63 +++++++++++++++++++++++++++++++- 3 files changed, 123 insertions(+), 9 deletions(-) diff --git a/jwt-authorizer/src/authorizer.rs b/jwt-authorizer/src/authorizer.rs index d80963c..e457882 100644 --- a/jwt-authorizer/src/authorizer.rs +++ b/jwt-authorizer/src/authorizer.rs @@ -14,6 +14,7 @@ use crate::{ }; pub type ClaimsCheckerFn = Arc bool + Send + Sync>>; +pub type TokenExtractorFn = Arc Option + Send + Sync>>; pub struct Authorizer where @@ -23,6 +24,7 @@ where pub claims_checker: Option>, pub validation: crate::validation::Validation, pub jwt_source: JwtSource, + pub token_extractor: Option, } fn read_data(path: &str) -> Result, InitError> { @@ -56,6 +58,7 @@ where refresh: Option, validation: crate::validation::Validation, jwt_source: JwtSource, + token_extractor: Option, http_client: Option, ) -> Result, InitError> { Ok(match key_source_type { @@ -77,6 +80,7 @@ where claims_checker, validation, jwt_source, + token_extractor, } } KeySourceType::RSAString(text) => { @@ -97,6 +101,7 @@ where claims_checker, validation, jwt_source, + token_extractor, } } KeySourceType::EC(path) => { @@ -110,6 +115,7 @@ where claims_checker, validation, jwt_source, + token_extractor, } } KeySourceType::ECString(text) => { @@ -123,6 +129,7 @@ where claims_checker, validation, jwt_source, + token_extractor, } } KeySourceType::ED(path) => { @@ -136,6 +143,7 @@ where claims_checker, validation, jwt_source, + token_extractor, } } KeySourceType::EDString(text) => { @@ -149,6 +157,7 @@ where claims_checker, validation, jwt_source, + token_extractor, } } KeySourceType::Secret(secret) => { @@ -162,6 +171,7 @@ where claims_checker, validation, jwt_source, + token_extractor, } } KeySourceType::JwksPath(path) => { @@ -179,6 +189,7 @@ where claims_checker, validation, jwt_source, + token_extractor, } } KeySourceType::JwksString(jwks_str) => { @@ -197,6 +208,7 @@ where claims_checker, validation, jwt_source, + token_extractor, } } KeySourceType::Jwks(url) => { @@ -207,6 +219,7 @@ where claims_checker, validation, jwt_source, + token_extractor, } } KeySourceType::Discovery(issuer_url) => { @@ -219,6 +232,7 @@ where claims_checker, validation, jwt_source, + token_extractor, } } }) @@ -241,14 +255,18 @@ where } pub fn extract_token(&self, h: &HeaderMap) -> Option { - match &self.jwt_source { - layer::JwtSource::AuthorizationHeader => { - let bearer_o: Option> = h.typed_get(); - bearer_o.map(|b| String::from(b.0.token())) + if let Some(ref extractor) = self.token_extractor { + extractor(h) + } else { + match &self.jwt_source { + layer::JwtSource::AuthorizationHeader => { + let bearer_o: Option> = h.typed_get(); + bearer_o.map(|b| String::from(b.0.token())) + } + layer::JwtSource::Cookie(name) => h + .typed_get::() + .and_then(|c| c.get(name.as_str()).map(String::from)), } - layer::JwtSource::Cookie(name) => h - .typed_get::() - .and_then(|c| c.get(name.as_str()).map(String::from)), } } } @@ -334,6 +352,7 @@ mod tests { Validation::new(), JwtSource::AuthorizationHeader, None, + None, ) .await .unwrap(); @@ -360,6 +379,7 @@ mod tests { Validation::new(), JwtSource::AuthorizationHeader, None, + None, ) .await .unwrap(); @@ -376,6 +396,7 @@ mod tests { Validation::new(), JwtSource::AuthorizationHeader, None, + None, ) .await .unwrap(); @@ -389,6 +410,7 @@ mod tests { Validation::new(), JwtSource::AuthorizationHeader, None, + None, ) .await .unwrap(); @@ -402,6 +424,7 @@ mod tests { Validation::new(), JwtSource::AuthorizationHeader, None, + None, ) .await .unwrap(); @@ -415,6 +438,7 @@ mod tests { Validation::new(), JwtSource::AuthorizationHeader, None, + None, ) .await .unwrap(); @@ -441,6 +465,7 @@ mod tests { Validation::new(), JwtSource::AuthorizationHeader, None, + None, ) .await .unwrap(); @@ -454,6 +479,7 @@ mod tests { Validation::new(), JwtSource::AuthorizationHeader, None, + None, ) .await .unwrap(); @@ -467,6 +493,7 @@ mod tests { Validation::new(), JwtSource::AuthorizationHeader, None, + None, ) .await .unwrap(); @@ -483,6 +510,7 @@ mod tests { Validation::new(), JwtSource::AuthorizationHeader, None, + None, ) .await; println!("{:?}", a.as_ref().err()); @@ -498,6 +526,7 @@ mod tests { Validation::default(), JwtSource::AuthorizationHeader, None, + None, ) .await; println!("{:?}", a.as_ref().err()); @@ -513,6 +542,7 @@ mod tests { Validation::default(), JwtSource::AuthorizationHeader, None, + None, ) .await; println!("{:?}", a.as_ref().err()); diff --git a/jwt-authorizer/src/builder.rs b/jwt-authorizer/src/builder.rs index 5f8799d..9f17979 100644 --- a/jwt-authorizer/src/builder.rs +++ b/jwt-authorizer/src/builder.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use serde::de::DeserializeOwned; use crate::{ - authorizer::{ClaimsCheckerFn, KeySourceType}, + authorizer::{ClaimsCheckerFn, KeySourceType, TokenExtractorFn}, error::InitError, layer::{AuthorizationLayer, JwtSource}, Authorizer, Refresh, RefreshStrategy, RegisteredClaims, Validation, @@ -24,6 +24,7 @@ where claims_checker: Option>, validation: Option, jwt_source: JwtSource, + token_extractor: Option, http_client: Option, } @@ -43,6 +44,7 @@ where claims_checker: None, validation: None, jwt_source: JwtSource::AuthorizationHeader, + token_extractor: None, http_client: None, } } @@ -55,6 +57,7 @@ where claims_checker: None, validation: None, jwt_source: JwtSource::AuthorizationHeader, + token_extractor: None, http_client: None, } } @@ -66,6 +69,7 @@ where claims_checker: None, validation: None, jwt_source: JwtSource::AuthorizationHeader, + token_extractor: None, http_client: None, } } @@ -77,6 +81,7 @@ where claims_checker: None, validation: None, jwt_source: JwtSource::AuthorizationHeader, + token_extractor: None, http_client: None, } } @@ -89,6 +94,7 @@ where claims_checker: None, validation: None, jwt_source: JwtSource::AuthorizationHeader, + token_extractor: None, http_client: None, } } @@ -101,6 +107,7 @@ where claims_checker: None, validation: None, jwt_source: JwtSource::AuthorizationHeader, + token_extractor: None, http_client: None, } } @@ -113,6 +120,7 @@ where claims_checker: None, validation: None, jwt_source: JwtSource::AuthorizationHeader, + token_extractor: None, http_client: None, } } @@ -125,6 +133,7 @@ where claims_checker: None, validation: None, jwt_source: JwtSource::AuthorizationHeader, + token_extractor: None, http_client: None, } } @@ -137,6 +146,7 @@ where claims_checker: None, validation: None, jwt_source: JwtSource::AuthorizationHeader, + token_extractor: None, http_client: None, } } @@ -149,6 +159,7 @@ where claims_checker: None, validation: None, jwt_source: JwtSource::AuthorizationHeader, + token_extractor: None, http_client: None, } } @@ -161,6 +172,7 @@ where claims_checker: None, validation: None, jwt_source: JwtSource::AuthorizationHeader, + token_extractor: None, http_client: None, } } @@ -212,6 +224,15 @@ where self } + /// configures the token extractor function + /// + /// (default: None) + pub fn token_extractor(mut self, token_extractor: TokenExtractorFn) -> AuthorizerBuilder { + self.token_extractor = Some(token_extractor); + + self + } + /// provide a custom http client for oicd requests /// if not called, uses a default configured client /// @@ -233,6 +254,7 @@ where self.refresh, val, self.jwt_source, + self.token_extractor, None, ) .await?, @@ -249,6 +271,7 @@ where self.refresh, val, self.jwt_source, + self.token_extractor, self.http_client, ) .await diff --git a/jwt-authorizer/tests/tests.rs b/jwt-authorizer/tests/tests.rs index 03acdec..f5e3207 100644 --- a/jwt-authorizer/tests/tests.rs +++ b/jwt-authorizer/tests/tests.rs @@ -14,7 +14,7 @@ mod tests { use http::{header, HeaderValue}; use jsonwebtoken::Algorithm; use jwt_authorizer::{ - authorizer::Authorizer, + authorizer::{Authorizer, TokenExtractorFn}, layer::{AuthorizationLayer, JwtSource}, validation::Validation, IntoLayer, JwtAuthorizer, JwtClaims, @@ -543,4 +543,65 @@ mod tests { .await; assert_eq!(response.status(), StatusCode::OK); } + + // -------------------- + // token_extractor + // --------------------- + #[tokio::test] + async fn jwt_custom_token_extractor() { + // Initialize custom token extractor + let token_extractor: TokenExtractorFn = Arc::new(Box::new(|headers| { + let Some(custom_header) = headers.get("X-Custom-Authorization") else { + return None; + }; + + let Ok(custom_header_str) = custom_header.to_str() else { + return None; + }; + + let token = custom_header_str.split("Bearer "); + + match token.last() { + Some(t) => Some(t.to_string()), + None => None, + } + })); + + // OK + let response = proteced_request_with_header( + JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem") + .validation(Validation::new().aud(&["aud1"])) + .token_extractor(token_extractor.clone()), + "X-Custom-Authorization", + &format!("Bearer {}", common::JWT_RSA1_OK), + ) + .await; + assert_eq!(response.status(), StatusCode::OK); + + // Header missing + let response = proteced_request_with_header( + JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem").token_extractor(token_extractor.clone()), + "X-Custom-Authorization", + "", + ) + .await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + assert_eq!( + response.headers().get(header::WWW_AUTHENTICATE).unwrap(), + &"Bearer error=\"invalid_token\"" + ); + + // Invalid Token + let response = proteced_request_with_header( + JwtAuthorizer::from_rsa_pem("../config/rsa-public1.pem").token_extractor(token_extractor.clone()), + "X-Custom-Authorization", + &format!("Bearer {}", common::JWT_EC2_OK), + ) + .await; + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + assert_eq!( + response.headers().get(header::WWW_AUTHENTICATE).unwrap(), + &"Bearer error=\"invalid_token\"" + ); + } } From 9f94a3e2fa6b83c0797d2a5c017f090acacfcf82 Mon Sep 17 00:00:00 2001 From: Vladislav Manchev Date: Thu, 25 Apr 2024 21:13:35 +0300 Subject: [PATCH 2/4] Fix clippy errors --- jwt-authorizer/tests/tests.rs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/jwt-authorizer/tests/tests.rs b/jwt-authorizer/tests/tests.rs index f5e3207..c170910 100644 --- a/jwt-authorizer/tests/tests.rs +++ b/jwt-authorizer/tests/tests.rs @@ -551,9 +551,7 @@ mod tests { async fn jwt_custom_token_extractor() { // Initialize custom token extractor let token_extractor: TokenExtractorFn = Arc::new(Box::new(|headers| { - let Some(custom_header) = headers.get("X-Custom-Authorization") else { - return None; - }; + let custom_header = headers.get("X-Custom-Authorization")?; let Ok(custom_header_str) = custom_header.to_str() else { return None; @@ -561,10 +559,7 @@ mod tests { let token = custom_header_str.split("Bearer "); - match token.last() { - Some(t) => Some(t.to_string()), - None => None, - } + token.last().map(|t| t.to_string()) })); // OK From c9d39ea135323cea69854c85814c61c4648cae59 Mon Sep 17 00:00:00 2001 From: Vladislav Manchev Date: Wed, 1 Jan 2025 22:18:05 +0200 Subject: [PATCH 3/4] Fix clippy errors --- jwt-authorizer/src/jwks/key_store_manager.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jwt-authorizer/src/jwks/key_store_manager.rs b/jwt-authorizer/src/jwks/key_store_manager.rs index 86eba6a..2c08cb9 100644 --- a/jwt-authorizer/src/jwks/key_store_manager.rs +++ b/jwt-authorizer/src/jwks/key_store_manager.rs @@ -116,7 +116,7 @@ impl KeyStoreManager { .await?; ks_gard .find_alg(&header.alg) - .ok_or_else(|| AuthError::InvalidKeyAlg(header.alg))? + .ok_or(AuthError::InvalidKeyAlg(header.alg))? } else { return Err(AuthError::InvalidKeyAlg(header.alg)); } From d9756c2489d0dc2afaedca8d7fc9b863e89a027e Mon Sep 17 00:00:00 2001 From: Vladislav Manchev Date: Wed, 1 Jan 2025 22:33:49 +0200 Subject: [PATCH 4/4] Formatting --- jwt-authorizer/src/jwks/key_store_manager.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/jwt-authorizer/src/jwks/key_store_manager.rs b/jwt-authorizer/src/jwks/key_store_manager.rs index 2c08cb9..70f551e 100644 --- a/jwt-authorizer/src/jwks/key_store_manager.rs +++ b/jwt-authorizer/src/jwks/key_store_manager.rs @@ -114,9 +114,7 @@ impl KeyStoreManager { )], ) .await?; - ks_gard - .find_alg(&header.alg) - .ok_or(AuthError::InvalidKeyAlg(header.alg))? + ks_gard.find_alg(&header.alg).ok_or(AuthError::InvalidKeyAlg(header.alg))? } else { return Err(AuthError::InvalidKeyAlg(header.alg)); }