diff --git a/jwt-authorizer/src/authorizer.rs b/jwt-authorizer/src/authorizer.rs index d80963c..4b9ed6c 100644 --- a/jwt-authorizer/src/authorizer.rs +++ b/jwt-authorizer/src/authorizer.rs @@ -56,7 +56,7 @@ where refresh: Option, validation: crate::validation::Validation, jwt_source: JwtSource, - http_client: Option, + http_client: Client, ) -> Result, InitError> { Ok(match key_source_type { KeySourceType::RSA(path) => { @@ -201,7 +201,7 @@ where } KeySourceType::Jwks(url) => { let jwks_url = Url::parse(url.as_str()).map_err(|e| InitError::JwksUrlError(e.to_string()))?; - let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default()); + let key_store_manager = KeyStoreManager::new(http_client, jwks_url, refresh.unwrap_or_default()); Authorizer { key_source: KeySource::KeyStoreSource(key_store_manager), claims_checker, @@ -210,10 +210,10 @@ where } } KeySourceType::Discovery(issuer_url) => { - let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str(), http_client).await?) + let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str(), &http_client).await?) .map_err(|e| InitError::JwksUrlError(e.to_string()))?; - let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default()); + let key_store_manager = KeyStoreManager::new(http_client, jwks_url, refresh.unwrap_or_default()); Authorizer { key_source: KeySource::KeyStoreSource(key_store_manager), claims_checker, @@ -318,6 +318,7 @@ where mod tests { use jsonwebtoken::{Algorithm, Header}; + use reqwest::Client; use serde_json::Value; use crate::{layer::JwtSource, validation::Validation}; @@ -333,7 +334,7 @@ mod tests { None, Validation::new(), JwtSource::AuthorizationHeader, - None, + Client::default(), ) .await .unwrap(); @@ -359,7 +360,7 @@ mod tests { None, Validation::new(), JwtSource::AuthorizationHeader, - None, + Client::default(), ) .await .unwrap(); @@ -375,7 +376,7 @@ mod tests { None, Validation::new(), JwtSource::AuthorizationHeader, - None, + Client::default(), ) .await .unwrap(); @@ -388,7 +389,7 @@ mod tests { None, Validation::new(), JwtSource::AuthorizationHeader, - None, + Client::default(), ) .await .unwrap(); @@ -401,7 +402,7 @@ mod tests { None, Validation::new(), JwtSource::AuthorizationHeader, - None, + Client::default(), ) .await .unwrap(); @@ -414,7 +415,7 @@ mod tests { None, Validation::new(), JwtSource::AuthorizationHeader, - None, + Client::default(), ) .await .unwrap(); @@ -440,7 +441,7 @@ mod tests { None, Validation::new(), JwtSource::AuthorizationHeader, - None, + Client::default(), ) .await .unwrap(); @@ -453,7 +454,7 @@ mod tests { None, Validation::new(), JwtSource::AuthorizationHeader, - None, + Client::default(), ) .await .unwrap(); @@ -466,7 +467,7 @@ mod tests { None, Validation::new(), JwtSource::AuthorizationHeader, - None, + Client::default(), ) .await .unwrap(); @@ -482,7 +483,7 @@ mod tests { None, Validation::new(), JwtSource::AuthorizationHeader, - None, + Client::default(), ) .await; println!("{:?}", a.as_ref().err()); @@ -497,7 +498,7 @@ mod tests { None, Validation::default(), JwtSource::AuthorizationHeader, - None, + Client::default(), ) .await; println!("{:?}", a.as_ref().err()); @@ -512,7 +513,7 @@ mod tests { None, Validation::default(), JwtSource::AuthorizationHeader, - None, + Client::default(), ) .await; println!("{:?}", a.as_ref().err()); diff --git a/jwt-authorizer/src/builder.rs b/jwt-authorizer/src/builder.rs index d1ec1bb..de2e3f5 100644 --- a/jwt-authorizer/src/builder.rs +++ b/jwt-authorizer/src/builder.rs @@ -233,7 +233,7 @@ where self.refresh, val, self.jwt_source, - None, + self.http_client.unwrap_or_default(), ) .await?, ); @@ -249,7 +249,7 @@ where self.refresh, val, self.jwt_source, - self.http_client, + self.http_client.unwrap_or_default(), ) .await } diff --git a/jwt-authorizer/src/jwks/key_store_manager.rs b/jwt-authorizer/src/jwks/key_store_manager.rs index 2ff3edf..04b5575 100644 --- a/jwt-authorizer/src/jwks/key_store_manager.rs +++ b/jwt-authorizer/src/jwks/key_store_manager.rs @@ -1,5 +1,5 @@ use jsonwebtoken::{jwk::JwkSet, Algorithm}; -use reqwest::Url; +use reqwest::{Client, Url}; use std::{ sync::Arc, time::{Duration, Instant}, @@ -51,6 +51,7 @@ impl Default for Refresh { #[derive(Clone)] pub struct KeyStoreManager { + http_client: Client, key_url: Url, /// in case of fail loading (error or key not found), minimal interval refresh: Refresh, @@ -67,8 +68,9 @@ pub struct KeyStore { } impl KeyStoreManager { - pub(crate) fn new(key_url: Url, refresh: Refresh) -> KeyStoreManager { + pub(crate) fn new(http_client: Client, key_url: Url, refresh: Refresh) -> KeyStoreManager { KeyStoreManager { + http_client, key_url, refresh, keystore: Arc::new(Mutex::new(KeyStore { @@ -85,7 +87,7 @@ impl KeyStoreManager { let key = match self.refresh.strategy { RefreshStrategy::Interval => { if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) { - ks_gard.refresh(&self.key_url, &[]).await?; + ks_gard.refresh(&self.http_client, &self.key_url, &[]).await?; } ks_gard.get_key(header)? } @@ -95,7 +97,7 @@ impl KeyStoreManager { if let Some(jwk) = jwk_opt { jwk } else if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) { - ks_gard.refresh(&self.key_url, &[("kid", kid)]).await?; + ks_gard.refresh(&self.http_client, &self.key_url, &[("kid", kid)]).await?; ks_gard.find_kid(kid).ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))? } else { return Err(AuthError::InvalidKid(kid.to_owned())); @@ -107,6 +109,7 @@ impl KeyStoreManager { } else if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) { ks_gard .refresh( + &self.http_client, &self.key_url, &[( "alg", @@ -127,7 +130,7 @@ impl KeyStoreManager { // if jwks endpoint is down for the loading, respect retry_interval && ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) { - ks_gard.refresh(&self.key_url, &[]).await?; + ks_gard.refresh(&self.http_client, &self.key_url, &[]).await?; } ks_gard.get_key(header)? } @@ -151,8 +154,8 @@ impl KeyStore { } } - async fn refresh(&mut self, key_url: &Url, qparam: &[(&str, &str)]) -> Result<(), AuthError> { - reqwest::Client::new() + async fn refresh(&mut self, http_client: &Client, key_url: &Url, qparam: &[(&str, &str)]) -> Result<(), AuthError> { + http_client .get(key_url.as_ref()) .query(qparam) .send() @@ -216,7 +219,7 @@ mod tests { use jsonwebtoken::Algorithm; use jsonwebtoken::{jwk::Jwk, Header}; - use reqwest::Url; + use reqwest::{Client, Url}; use wiremock::{ matchers::{method, path}, Mock, MockServer, ResponseTemplate, @@ -366,6 +369,7 @@ mod tests { mock_jwks_response_once(&mock_server, JWK_ED01).await; let ksm = KeyStoreManager::new( + Client::default(), Url::parse(&mock_server.uri()).unwrap(), Refresh { strategy: RefreshStrategy::Interval, @@ -413,6 +417,7 @@ mod tests { mock_jwks_response_once(&mock_server, JWK_ED01).await; let ksm = KeyStoreManager::new( + Client::default(), Url::parse(&mock_server.uri()).unwrap(), Refresh { strategy: RefreshStrategy::KeyNotFound, @@ -472,6 +477,7 @@ mod tests { mock_jwks_response_once(&mock_server, JWK_ED01).await; let ksm = KeyStoreManager::new( + Client::default(), Url::parse(&mock_server.uri()).unwrap(), Refresh { strategy: RefreshStrategy::NoRefresh, diff --git a/jwt-authorizer/src/oidc.rs b/jwt-authorizer/src/oidc.rs index 567bdd2..7691d7b 100644 --- a/jwt-authorizer/src/oidc.rs +++ b/jwt-authorizer/src/oidc.rs @@ -20,9 +20,7 @@ fn discovery_url(issuer: &str) -> Result { Ok(url) } -pub async fn discover_jwks(issuer: &str, client: Option) -> Result { - let client = client.unwrap_or_default(); - +pub async fn discover_jwks(issuer: &str, client: &Client) -> Result { client .get(discovery_url(issuer)?) .send()