Skip to content

Add async/await support #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 12 commits into
base: master
Choose a base branch
from
7 changes: 4 additions & 3 deletions google_api_auth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ default = []
with-yup-oauth2 = ["yup-oauth2", "tokio"]

[dependencies]
yup-oauth2 = { version = "^4.1", optional = true }
tokio = { version = "0.2", optional = true }
hyper = "^0.13"
yup-oauth2 = { version = "5", optional = true }
tokio = { version = "1", optional = true }
hyper = "0.14"
async-trait = "0.1"
3 changes: 2 additions & 1 deletion google_api_auth/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
/// client libraries to retrieve access tokens when making http requests. This
/// library optionally provides a variety of implementations, but users are also
/// free to implement whatever logic they want for retrieving a token.
#[async_trait::async_trait]
pub trait GetAccessToken: ::std::fmt::Debug + Send + Sync {
fn access_token(&self) -> Result<String, Box<dyn ::std::error::Error + Send + Sync>>;
async fn access_token(&self) -> Result<String, Box<dyn ::std::error::Error + Send + Sync>>;
}

impl<T> From<T> for Box<dyn GetAccessToken>
Expand Down
7 changes: 3 additions & 4 deletions google_api_auth/src/yup_oauth2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@ impl<T> ::std::fmt::Debug for YupAuthenticator<T> {
}
}

#[async_trait::async_trait]
impl<C> crate::GetAccessToken for YupAuthenticator<C>
where
C: Connect + Clone + Send + Sync + 'static,
{
fn access_token(&self) -> Result<String, Box<dyn ::std::error::Error + Send + Sync>> {
let fut = self.auth.token(&self.scopes);
let mut runtime = ::tokio::runtime::Runtime::new().expect("unable to start tokio runtime");
Ok(runtime.block_on(fut)?.as_str().to_string())
async fn access_token(&self) -> Result<String, Box<dyn ::std::error::Error + Send + Sync>> {
Ok(self.auth.token(&self.scopes).await?.as_str().to_string())
}
}

Expand Down
15 changes: 6 additions & 9 deletions google_rest_api_generator/gen_include/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub enum Error {
reqwest_err: ::reqwest::Error,
body: Option<String>,
},
IO(std::io::Error),
Other(Box<dyn::std::error::Error + Send + Sync>),
}

Expand All @@ -15,6 +16,7 @@ impl Error {
Error::OAuth2(_) => None,
Error::JSON(err) => Some(err),
Error::Reqwest { .. } => None,
Error::IO(_) => None,
Error::Other(_) => None,
}
}
Expand All @@ -32,6 +34,7 @@ impl ::std::fmt::Display for Error {
}
Ok(())
}
Error::IO(err) => write!(f, "IO Error: {}", err),
Error::Other(err) => write!(f, "Uknown Error: {}", err),
}
}
Expand All @@ -54,14 +57,8 @@ impl From<::reqwest::Error> for Error {
}
}

/// Check the response to see if the status code represents an error. If so
/// convert it into the Reqwest variant of Error.
fn error_from_response(response: ::reqwest::blocking::Response) -> Result<::reqwest::blocking::Response, Error> {
match response.error_for_status_ref() {
Err(reqwest_err) => {
let body = response.text().ok();
Err(Error::Reqwest { reqwest_err, body })
}
Ok(_) => Ok(response),
impl From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Error {
Error::IO(err)
}
}
44 changes: 28 additions & 16 deletions google_rest_api_generator/gen_include/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub mod iter {
T: ::serde::de::DeserializeOwned;
}

pub struct PageIter<M, T>{
pub struct PageIter<M, T> {
pub method: M,
pub finished: bool,
pub _phantom: ::std::marker::PhantomData<T>,
Expand All @@ -18,7 +18,7 @@ pub mod iter {
T: ::serde::de::DeserializeOwned,
{
pub(crate) fn new(method: M) -> Self {
PageIter{
PageIter {
method,
finished: false,
_phantom: ::std::marker::PhantomData,
Expand All @@ -37,24 +37,30 @@ pub mod iter {
if self.finished {
return None;
}
let paginated_result: ::serde_json::Map<String, ::serde_json::Value> = match self.method.execute() {
Ok(r) => r,
Err(err) => return Some(Err(err)),
};
if let Some(next_page_token) = paginated_result.get("nextPageToken").and_then(|t| t.as_str()) {
let paginated_result: ::serde_json::Map<String, ::serde_json::Value> =
match self.method.execute() {
Ok(r) => r,
Err(err) => return Some(Err(err)),
};
if let Some(next_page_token) = paginated_result
.get("nextPageToken")
.and_then(|t| t.as_str())
{
self.method.set_page_token(next_page_token.to_owned());
} else {
self.finished = true;
}

Some(match ::serde_json::from_value(::serde_json::Value::Object(paginated_result)) {
Ok(resp) => Ok(resp),
Err(err) => Err(err.into())
})
Some(
match ::serde_json::from_value(::serde_json::Value::Object(paginated_result)) {
Ok(resp) => Ok(resp),
Err(err) => Err(err.into()),
},
)
}
}

pub struct PageItemIter<M, T>{
pub struct PageItemIter<M, T> {
items_field: &'static str,
page_iter: PageIter<M, ::serde_json::Map<String, ::serde_json::Value>>,
items: ::std::vec::IntoIter<T>,
Expand All @@ -66,7 +72,7 @@ pub mod iter {
T: ::serde::de::DeserializeOwned,
{
pub(crate) fn new(method: M, items_field: &'static str) -> Self {
PageItemIter{
PageItemIter {
items_field,
page_iter: PageIter::new(method),
items: Vec::new().into_iter(),
Expand All @@ -92,10 +98,16 @@ pub mod iter {
None => return None,
Some(Err(err)) => return Some(Err(err)),
Some(Ok(next_page)) => {
let mut next_page: ::serde_json::Map<String, ::serde_json::Value> = next_page;
let mut next_page: ::serde_json::Map<String, ::serde_json::Value> =
next_page;
let items_array = match next_page.remove(self.items_field) {
Some(items) => items,
None => return Some(Err(crate::Error::Other(format!("no {} field found in iter response", self.items_field).into()))),
None => {
return Some(Err(crate::Error::Other(
format!("no {} field found in iter response", self.items_field)
.into(),
)))
}
};
let items_vec: Result<Vec<T>, _> = ::serde_json::from_value(items_array);
match items_vec {
Expand All @@ -107,4 +119,4 @@ pub mod iter {
}
}
}
}
}
42 changes: 30 additions & 12 deletions google_rest_api_generator/gen_include/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ mod multipart {

pub(crate) struct Part {
content_type: ::mime::Mime,
body: Box<dyn::std::io::Read + Send>,
body: Box<dyn futures::io::AsyncRead + std::marker::Unpin + Send>,
}

impl Part {
pub(crate) fn new(
content_type: ::mime::Mime,
body: Box<dyn::std::io::Read + Send>,
body: Box<dyn futures::io::AsyncRead + std::marker::Unpin + Send>,
) -> Part {
Part { content_type, body }
}
Expand All @@ -52,26 +52,32 @@ mod multipart {
pub(crate) struct RelatedMultiPartReader {
state: RelatedMultiPartReaderState,
boundary: String,
next_body: Option<Box<dyn::std::io::Read + Send>>,
next_body: Option<Box<dyn futures::io::AsyncRead + std::marker::Unpin + Send>>,
parts: std::vec::IntoIter<Part>,
}

enum RelatedMultiPartReaderState {
WriteBoundary {
start: usize, boundary: String,
start: usize,
boundary: String,
},
WriteContentType {
start: usize,
content_type: Vec<u8>,
},
WriteBody {
body: Box<dyn::std::io::Read + Send>,
body: Box<dyn futures::io::AsyncRead + std::marker::Unpin + Send>,
},
}

impl ::std::io::Read for RelatedMultiPartReader {
fn read(&mut self, buf: &mut [u8]) -> ::std::io::Result<usize> {
impl futures::io::AsyncRead for RelatedMultiPartReader {
fn poll_read(
mut self: std::pin::Pin<&mut Self>,
ctx: &mut futures::task::Context,
buf: &mut [u8],
) -> futures::task::Poll<Result<usize, futures::io::Error>> {
use RelatedMultiPartReaderState::*;

let mut bytes_written: usize = 0;
loop {
let rem_buf = &mut buf[bytes_written..];
Expand All @@ -90,8 +96,11 @@ mod multipart {
self.next_body = Some(next_part.body);
self.state = WriteContentType {
start: 0,
content_type: format!("Content-Type: {}\r\n\r\n", next_part.content_type)
.into_bytes(),
content_type: format!(
"Content-Type: {}\r\n\r\n",
next_part.content_type
)
.into_bytes(),
};
} else {
break;
Expand All @@ -101,7 +110,8 @@ mod multipart {
start,
content_type,
} => {
let bytes_to_copy = std::cmp::min(content_type.len() - *start, rem_buf.len());
let bytes_to_copy =
std::cmp::min(content_type.len() - *start, rem_buf.len());
rem_buf[..bytes_to_copy]
.copy_from_slice(&content_type[*start..*start + bytes_to_copy]);
*start += bytes_to_copy;
Expand All @@ -115,7 +125,14 @@ mod multipart {
}
}
WriteBody { body } => {
let written = body.read(rem_buf)?;
let body = std::pin::Pin::new(body);
let written = match futures::io::AsyncRead::poll_read(body, ctx, rem_buf) {
futures::task::Poll::Ready(Ok(n)) => n,
futures::task::Poll::Ready(Err(err)) => {
return futures::task::Poll::Ready(Err(err));
}
futures::task::Poll::Pending => return futures::task::Poll::Pending,
};
bytes_written += written;
if written == 0 {
self.state = WriteBoundary {
Expand All @@ -128,7 +145,8 @@ mod multipart {
}
}
}
Ok(bytes_written)

futures::task::Poll::Ready(Ok(bytes_written))
}
}

Expand Down
14 changes: 7 additions & 7 deletions google_rest_api_generator/gen_include/resumable_upload.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
pub struct ResumableUpload {
reqwest: ::reqwest::blocking::Client,
reqwest: ::reqwest::Client,
url: String,
progress: Option<i64>,
}

impl ResumableUpload {
pub fn new(reqwest: ::reqwest::blocking::Client, url: String) -> Self {
pub fn new(reqwest: ::reqwest::Client, url: String) -> Self {
ResumableUpload {
reqwest,
url,
Expand All @@ -17,7 +17,7 @@ impl ResumableUpload {
&self.url
}

pub fn upload<R>(&mut self, mut reader: R) -> Result<(), Box<dyn::std::error::Error>>
pub async fn upload<R>(&mut self, mut reader: R) -> Result<(), Box<dyn::std::error::Error>>
where
R: ::std::io::Read + ::std::io::Seek + Send + 'static,
{
Expand All @@ -36,8 +36,8 @@ impl ResumableUpload {
::reqwest::header::CONTENT_RANGE,
format!("bytes */{}", reader_len),
);
let resp = req.send()?.error_for_status()?;
match resp.headers().get(::reqwest::header::RANGE) {
let response = req.send().await?.error_for_status()?;
match response.headers().get(::reqwest::header::RANGE) {
Some(range_header) => {
let (_, progress) = parse_range_header(range_header)
.map_err(|e| format!("invalid RANGE header: {}", e))?;
Expand All @@ -53,8 +53,8 @@ impl ResumableUpload {
let content_range = format!("bytes {}-{}/{}", progress, reader_len - 1, reader_len);
let req = self.reqwest.request(::reqwest::Method::PUT, &self.url);
let req = req.header(::reqwest::header::CONTENT_RANGE, content_range);
let req = req.body(::reqwest::blocking::Body::sized(reader, content_length));
req.send()?.error_for_status()?;
let req = req.body(::reqwest::Body::sized(reader, content_length));
req.send().await?.error_for_status()?;
Ok(())
}
}
Expand Down
27 changes: 20 additions & 7 deletions google_rest_api_generator/src/cargo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,23 @@ edition = "2018"
# for now, let's not even accidentally publish these
publish = false

[features]
default = ["rustls-tls"]

native-tls = ["reqwest/native-tls"]
rustls-tls = ["reqwest/rustls-tls"]

[dependencies]
serde = { version = "1", features = ["derive"] }
serde_json = "1"
chrono = { version = "0.4", features = ["serde"] }
reqwest = { version = "0.10", default-features = false, features = ['rustls-tls', 'blocking', 'json'] }
google_field_selector = { git = "https://github.com/google-apis-rs/generator" }
google_api_auth = { git = "https://github.com/google-apis-rs/generator" }
futures = "0.3"
google_api_auth = { git = "https://github.com/google-apis-rs/generator", branch = "refactor/async" }
google_field_selector = { git = "https://github.com/google-apis-rs/generator", branch = "refactor/async" }
mime = "0.3"
textnonce = "0.6"
percent-encoding = "2"
reqwest = { version = "0.11", default-features = false, features = ["json"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
textnonce = "1"
"#;

pub(crate) fn cargo_toml(crate_name: &str, include_bytes_dep: bool, api: &shared::Api) -> String {
Expand All @@ -30,9 +37,15 @@ pub(crate) fn cargo_toml(crate_name: &str, include_bytes_dep: bool, api: &shared
.expect("available crate version"),
);

// TODO: figure out a better way to determine if we should include stream reqwest feature
if crate_name.contains("storage") {
doc = doc.replace(r#"features = ["json"]"#, r#"features = ["stream", "json"]"#);
}

if include_bytes_dep {
doc.push_str("\n[dependencies.google_api_bytes]\n");
doc.push_str("\n\n[dependencies.google_api_bytes]\n");
doc.push_str("git = \"https://github.com/google-apis-rs/generator\"\n");
}

doc
}
Loading