diff --git a/Cargo.toml b/Cargo.toml index 5aba3b5..07817d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,13 @@ [package] name = "firestore-db-and-auth" -version = "0.8.0" +version = "0.8.1" authors = ["David Gräff "] edition = "2021" license = "MIT" description = "This crate allows easy access to your Google Firestore DB via service account or OAuth impersonated Google Firebase Auth credentials." readme = "readme.md" keywords = ["firestore", "auth"] -categories = ["api-bindings","authentication"] +categories = ["api-bindings", "authentication"] maintenance = { status = "passively-maintained" } repository = "https://github.com/davidgraeff/firestore-db-and-auth-rs" rust-version = "1.64" @@ -15,7 +15,7 @@ rust-version = "1.64" [dependencies] bytes = "1.1" cache_control = "0.2" -reqwest = { version = "0.11", default-features = false, features = ["json", "blocking", "hyper-rustls"] } +reqwest = { version = "0.11.16", default-features = false, features = ["json"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" chrono = { version = "0.4", features = ["serde"] } @@ -38,10 +38,11 @@ optional = true # Render the readme file on doc.rs [package.metadata.docs.rs] -features = [ "external_doc", "rocket_support" ] +features = ["external_doc", "rocket_support"] [features] -default = ["rustls-tls", "unstable"] +default = ["rustls-tls", "unstable", "blocking"] +blocking = ["reqwest/blocking"] rocket_support = ["rocket"] rustls-tls = ["reqwest/rustls-tls"] default-tls = ["reqwest/default-tls"] @@ -49,6 +50,7 @@ native-tls = ["reqwest/native-tls"] native-tls-vendored = ["reqwest/native-tls-vendored"] unstable = [] external_doc = [] +wasm32 = ["ring/wasm32_unknown_unknown_js", "ring/std", "tokio/rt", "tokio/sync"] [[example]] name = "create_read_write_document" @@ -65,4 +67,4 @@ test = true [[example]] name = "rocket_http_protected_route" test = true -required-features = ["rustls-tls","rocket_support"] +required-features = ["rustls-tls", "rocket_support"] diff --git a/src/documents/list.rs b/src/documents/list.rs index 2d5e1c0..a17c327 100644 --- a/src/documents/list.rs +++ b/src/documents/list.rs @@ -6,7 +6,8 @@ use futures::{ task::{Context, Poll}, Future, }; -use std::boxed::Box; + +use std::sync::Arc; /// List all documents of a given collection. /// @@ -44,15 +45,15 @@ use std::boxed::Box; pub fn list( auth: &AUTH, collection_id: impl Into, -) -> Pin> + Send>> +) -> impl Stream> where - for<'b> T: Deserialize<'b> + 'static, - AUTH: FirebaseAuthBearer + Clone + Send + Sync + 'static, + for<'b> T: Deserialize<'b> + 'static, + AUTH: FirebaseAuthBearer + Clone + Send + Sync + 'static, { let auth = auth.clone(); let collection_id = collection_id.into(); - Box::pin(stream::unfold( + stream::unfold( ListInner { url: firebase_url(auth.project_id(), &collection_id), auth, @@ -116,7 +117,7 @@ where )), } }, - )) + ) } async fn get_new_data<'a>( diff --git a/src/errors.rs b/src/errors.rs index a0c500b..9f42f6d 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -6,6 +6,7 @@ use std::fmt; use reqwest; use reqwest::StatusCode; use serde::{Deserialize, Serialize}; +use crate::http2; /// A result type that uses [`FirebaseError`] as an error type pub type Result = std::result::Result; @@ -168,17 +169,25 @@ struct GoogleRESTApiErrorWrapper { /// Arguments: /// - response: The http requests response. Must be mutable, because the contained value will be extracted in an error case /// - context: A function that will be called in an error case that returns a context string -pub(crate) fn extract_google_api_error( - response: reqwest::blocking::Response, +pub(crate) async fn extract_google_api_error( + response: http2::Response, context: impl Fn() -> String, -) -> Result { +) -> Result { if response.status() == 200 { return Ok(response); } + let status = response.status().clone(); + + #[cfg(feature = "blocking")] + let res_text = response.text()?; + + #[cfg(not(feature = "blocking"))] + let res_text = response.text().await?; + Err(extract_google_api_error_intern( - response.status().clone(), - response.text()?, + status, + res_text, context, )) } diff --git a/src/http2.rs b/src/http2.rs new file mode 100644 index 0000000..812410e --- /dev/null +++ b/src/http2.rs @@ -0,0 +1,11 @@ +#[cfg(feature = "blocking")] +pub type Response = reqwest::blocking::Response; + +#[cfg(not(feature = "blocking"))] +pub type Response = reqwest::Response; + +#[cfg(feature = "blocking")] +pub type Client = reqwest::blocking::Client; + +#[cfg(not(feature = "blocking"))] +pub type Client = reqwest::Client; diff --git a/src/lib.rs b/src/lib.rs index 4f296ae..0cc7058 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,7 +12,9 @@ pub mod users; #[cfg(feature = "rocket_support")] pub mod rocket; +mod http2; +use async_trait::async_trait; // Forward declarations pub use credentials::Credentials; pub use jwt::JWKSet; @@ -26,7 +28,8 @@ pub use sessions::user::Session as UserSession; /// Firestore document methods in [`crate::documents`] expect an object that implements this `FirebaseAuthBearer` trait. /// /// Implement this trait for your own data structure and provide the Firestore project id and a valid access token. -#[async_trait::async_trait] +#[cfg_attr(target_family = "wasm", async_trait(?Send))] +#[cfg_attr(not(target_family = "wasm"), async_trait)] pub trait FirebaseAuthBearer { /// Return the project ID. This is required for the firebase REST API. fn project_id(&self) -> &str; diff --git a/src/sessions.rs b/src/sessions.rs index 5afca2b..4a194a2 100644 --- a/src/sessions.rs +++ b/src/sessions.rs @@ -20,9 +20,11 @@ use std::sync::Arc; use tokio::sync::RwLock; pub mod user { + use async_trait::async_trait; use super::*; use crate::dto::{OAuthResponse, SignInWithIdpRequest}; use credentials::Credentials; + use crate::http2::Client; #[inline] fn token_endpoint(v: &str) -> String { @@ -86,16 +88,13 @@ pub mod user { pub client: reqwest::Client, } - #[async_trait::async_trait] + #[cfg_attr(target_family = "wasm", async_trait(?Send))] + #[cfg_attr(not(target_family = "wasm"), async_trait)] impl super::FirebaseAuthBearer for Session { fn project_id(&self) -> &str { &self.project_id_ } - async fn access_token_unchecked(&self) -> String { - self.access_token_.read().await.clone() - } - /// Returns the current access token. /// This method will automatically refresh your access token, if it has expired. /// @@ -107,7 +106,7 @@ pub mod user { if is_expired(&jwt, 0).unwrap() { // Unwrap: the token is always valid at this point - if let Ok(response) = get_new_access_token(&self.api_key, &jwt).await { + if let Ok(response) = get_new_access_token(&self.client, &self.api_key, &jwt).await { *jwt = response.id_token.clone(); return response.id_token; } else { @@ -119,6 +118,10 @@ pub mod user { jwt.clone() } + async fn access_token_unchecked(&self) -> String { + self.access_token_.read().await.clone() + } + fn client(&self) -> &reqwest::Client { &self.client } @@ -126,15 +129,15 @@ pub mod user { /// Gets a new access token via an api_key and a refresh_token. async fn get_new_access_token( + client: &Client, api_key: &str, refresh_token: &str, ) -> Result { let request_body = vec![("grant_type", "refresh_token"), ("refresh_token", refresh_token)]; let url = refresh_to_access_endpoint(api_key); - let client = reqwest::Client::new(); let response = client.post(&url).form(&request_body).send().await?; - Ok(response.json().await?) + Ok(response.json::().await?) } #[allow(non_snake_case)] @@ -235,15 +238,16 @@ pub mod user { credentials: &Credentials, refresh_token: &str, ) -> Result { + let client = Client::new(); let r: RefreshTokenToAccessTokenResponse = - get_new_access_token(&credentials.api_key, refresh_token).await?; + get_new_access_token(&client, &credentials.api_key, refresh_token).await?; Ok(Session { user_id: r.user_id, access_token_: Arc::new(RwLock::new(r.id_token)), refresh_token: Some(r.refresh_token), project_id_: credentials.project_id.to_owned(), api_key: credentials.api_key.clone(), - client: reqwest::Client::new(), + client: client, }) } @@ -361,6 +365,7 @@ pub mod user { } pub mod session_cookie { + use crate::http2; use super::*; pub static GOOGLE_OAUTH2_URL: &str = "https://accounts.google.com/o/oauth2/token"; @@ -430,18 +435,24 @@ pub mod session_cookie { let assertion = crate::jwt::session_cookie::create_jwt_encoded(credentials, duration).await?; // Request Google Oauth2 to retrieve the access token in order to create a session cookie - let client = reqwest::blocking::Client::new(); - let response_oauth2: Oauth2ResponseDTO = client + let client = http2::Client::new(); + + let _res_oauth = client .post(GOOGLE_OAUTH2_URL) .form(&[ ("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"), ("assertion", &assertion), ]) - .send()? - .json()?; + .send(); + + #[cfg(feature = "blocking")] + let response_oauth2: Oauth2ResponseDTO = _res_oauth?.json()?; + + #[cfg(not(feature = "blocking"))] + let response_oauth2: Oauth2ResponseDTO = _res_oauth.await?.json().await?; // Create a session cookie with the access token previously retrieved - let response_session_cookie_json: CreateSessionCookieResponseDTO = client + let _res_cookie = client .post(&identitytoolkit_url(&credentials.project_id)) .bearer_auth(&response_oauth2.access_token) .json(&SessionLoginDTO { @@ -449,8 +460,13 @@ pub mod session_cookie { valid_duration: duration.num_seconds() as u64, tenant_id: None, }) - .send()? - .json()?; + .send(); + + #[cfg(feature = "blocking")] + let response_session_cookie_json: CreateSessionCookieResponseDTO = _res_cookie?.json()?; + + #[cfg(not(feature = "blocking"))] + let response_session_cookie_json: CreateSessionCookieResponseDTO = _res_cookie.await?.json().await?; Ok(response_session_cookie_json.session_cookie_jwk) } @@ -466,6 +482,8 @@ pub mod service_account { use chrono::Duration; use std::cell::RefCell; use std::ops::Deref; + use async_trait::async_trait; + use crate::http2; /// Service account session #[derive(Clone, Debug)] @@ -478,7 +496,8 @@ pub mod service_account { access_token_: Arc>, } - #[async_trait::async_trait] + #[cfg_attr(target_family = "wasm", async_trait(?Send))] + #[cfg_attr(not(target_family = "wasm"), async_trait)] impl super::FirebaseAuthBearer for Session { fn project_id(&self) -> &str { &self.credentials.project_id