Skip to content

Commit 5bd06be

Browse files
authored
Custom core Request and Response (#230)
* Request/Response * Ryan's suggestions and support for retry file and seekable stream * Mock implementation for WASM * fixed unused imports due to WASM * fixed unused imports due to hyper being temporarily disabled
1 parent b7c0a9c commit 5bd06be

22 files changed

+817
-286
lines changed

sdk/core/Cargo.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@ bytes = "1.0"
2929
hyper-rustls = { version = "0.22", optional = true }
3030
async-trait = "0.1"
3131
oauth2 = "4.0.0"
32-
reqwest = { version = "0.11", optional = true }
32+
reqwest = { version = "0.11", features = ["stream"], optional = true }
33+
rand = "0.7"
34+
dyn-clone = "1.0"
3335

3436
[dev-dependencies]
35-
tokio = "1.0"
3637
env_logger = "0.8"
38+
tokio = { version = "1.0", features = ["default"] }
3739

3840
[features]
3941
default = ["enable_reqwest"]

sdk/core/src/bytes_stream.rs

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
use bytes::Bytes;
2+
use futures::io::AsyncRead;
3+
use futures::stream::Stream;
4+
use std::pin::Pin;
5+
use std::task::Poll;
6+
7+
use crate::SeekableStream;
8+
9+
/// Convenience struct that maps a `bytes::Bytes` buffer into a stream.
10+
///
11+
/// This struct implements both `Stream` and `SeekableStream` for an
12+
/// immutable bytes buffer. It's cheap to clone but remember to `reset`
13+
/// the stream position if you clone it.
14+
#[derive(Debug, Clone)]
15+
pub struct BytesStream {
16+
bytes: Bytes,
17+
bytes_read: usize,
18+
}
19+
20+
impl BytesStream {
21+
pub fn new(bytes: impl Into<Bytes>) -> Self {
22+
Self {
23+
bytes: bytes.into(),
24+
bytes_read: 0,
25+
}
26+
}
27+
28+
/// Creates a stream that resolves immediately with no data.
29+
pub fn new_empty() -> Self {
30+
Self::new(Bytes::new())
31+
}
32+
}
33+
34+
impl From<Bytes> for BytesStream {
35+
fn from(bytes: Bytes) -> Self {
36+
Self::new(bytes)
37+
}
38+
}
39+
40+
impl Stream for BytesStream {
41+
type Item = Result<Bytes, Box<dyn std::error::Error + Send + Sync>>;
42+
43+
fn poll_next(
44+
self: Pin<&mut Self>,
45+
_cx: &mut std::task::Context<'_>,
46+
) -> Poll<Option<Self::Item>> {
47+
let mut self_mut = self.get_mut();
48+
49+
// we return all the available bytes in one call.
50+
if self_mut.bytes_read < self_mut.bytes.len() {
51+
let bytes_read = self_mut.bytes_read;
52+
self_mut.bytes_read = self_mut.bytes.len();
53+
Poll::Ready(Some(Ok(self_mut.bytes.slice(bytes_read..))))
54+
} else {
55+
Poll::Ready(None)
56+
}
57+
}
58+
}
59+
60+
#[async_trait::async_trait]
61+
impl SeekableStream for BytesStream {
62+
async fn reset(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
63+
self.bytes_read = 0;
64+
Ok(())
65+
}
66+
}
67+
68+
impl AsyncRead for BytesStream {
69+
fn poll_read(
70+
self: Pin<&mut Self>,
71+
_cx: &mut std::task::Context<'_>,
72+
buf: &mut [u8],
73+
) -> Poll<std::io::Result<usize>> {
74+
let mut self_mut = self.get_mut();
75+
76+
if self_mut.bytes_read < self_mut.bytes.len() {
77+
let bytes_read = self_mut.bytes_read;
78+
let remaining_bytes = self_mut.bytes.len() - bytes_read;
79+
80+
let bytes_to_copy = std::cmp::min(remaining_bytes, buf.len());
81+
82+
for (buf_byte, bytes_byte) in buf
83+
.iter_mut()
84+
.zip(self_mut.bytes.slice(self_mut.bytes_read..bytes_to_copy))
85+
{
86+
*buf_byte = bytes_byte;
87+
}
88+
89+
self_mut.bytes_read += bytes_to_copy;
90+
91+
Poll::Ready(Ok(bytes_to_copy))
92+
} else {
93+
Poll::Ready(Ok(0))
94+
}
95+
}
96+
}

sdk/core/src/context.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/// Pipeline execution context.
2+
///
3+
/// During a pipeline execution, context will be passed from the function starting the
4+
/// pipeline down to each pipeline policy. Contrarily to the Request, the context can be mutated
5+
/// by each pipeline policy and is not reset between retries. It can be used to pass the whole
6+
/// pipeline execution history between policies.
7+
/// For example, it could be used to signal that an execution failed because a CosmosDB endpoint is
8+
/// down and the appropriate policy should try the next one).
9+
#[derive(Clone)]
10+
pub struct Context {
11+
// Temporary hack to make sure that Context is not initializeable
12+
// Soon Context will have proper data fields
13+
_priv: (),
14+
}
15+
16+
impl Context {
17+
pub fn new() -> Self {
18+
Self { _priv: () }
19+
}
20+
}

sdk/core/src/http.rs

Lines changed: 0 additions & 44 deletions
This file was deleted.

sdk/core/src/http_client.rs

Lines changed: 103 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1-
use crate::errors::*;
1+
use crate::errors::{AzureError, UnexpectedHTTPResult};
2+
#[allow(unused_imports)]
3+
use crate::Body;
24
use async_trait::async_trait;
35
use bytes::Bytes;
6+
#[allow(unused_imports)]
7+
use futures::TryStreamExt;
48
use http::{Request, Response, StatusCode};
59
#[cfg(feature = "enable_hyper")]
10+
#[allow(unused_imports)]
611
use hyper_rustls::HttpsConnector;
712
use serde::Serialize;
813

@@ -14,6 +19,17 @@ pub trait HttpClient: Send + Sync + std::fmt::Debug {
1419
request: Request<Bytes>,
1520
) -> Result<Response<Bytes>, Box<dyn std::error::Error + Sync + Send>>;
1621

22+
/// This function will be the only one remaining in the trait as soon as the trait stabilizes.
23+
/// It will be renamed to `execute_request`. The other helper functions (ie
24+
/// `execute_request_check_status`) will be removed since the status check will be
25+
/// responsibility of another policy (not the transport one). It does not consume the request.
26+
/// Implementors are expected to clone the necessary parts of the request and pass them to the
27+
/// underlying transport.
28+
async fn execute_request2(
29+
&self,
30+
request: &crate::Request,
31+
) -> Result<crate::Response, Box<dyn std::error::Error + Sync + Send>>;
32+
1733
async fn execute_request_check_status(
1834
&self,
1935
request: Request<Bytes>,
@@ -62,85 +78,127 @@ pub trait HttpClient: Send + Sync + std::fmt::Debug {
6278
}
6379
}
6480

65-
#[cfg(feature = "enable_hyper")]
81+
// TODO: To reimplement once the Request and Response are validated.
82+
//#[cfg(feature = "enable_hyper")]
83+
//#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
84+
//#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
85+
//impl HttpClient for hyper::Client<HttpsConnector<hyper::client::HttpConnector>> {
86+
// async fn execute_request(
87+
// &self,
88+
// request: Request<Bytes>,
89+
// ) -> Result<Response<Bytes>, Box<dyn std::error::Error + Sync + Send>> {
90+
// let mut hyper_request = hyper::Request::builder()
91+
// .uri(request.uri())
92+
// .method(request.method());
93+
//
94+
// for header in request.headers() {
95+
// hyper_request = hyper_request.header(header.0, header.1);
96+
// }
97+
//
98+
// let hyper_request = hyper_request.body(hyper::Body::from(request.into_body()))?;
99+
//
100+
// let hyper_response = self.request(hyper_request).await?;
101+
//
102+
// let mut response = Response::builder()
103+
// .status(hyper_response.status())
104+
// .version(hyper_response.version());
105+
//
106+
// for (key, value) in hyper_response.headers() {
107+
// response = response.header(key, value);
108+
// }
109+
//
110+
// let response = response.body(hyper::body::to_bytes(hyper_response.into_body()).await?)?;
111+
//
112+
// Ok(response)
113+
// }
114+
//}
115+
116+
#[cfg(feature = "enable_reqwest")]
66117
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
67118
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
68-
impl HttpClient for hyper::Client<HttpsConnector<hyper::client::HttpConnector>> {
119+
impl HttpClient for reqwest::Client {
69120
async fn execute_request(
70121
&self,
71122
request: Request<Bytes>,
72123
) -> Result<Response<Bytes>, Box<dyn std::error::Error + Sync + Send>> {
73-
let mut hyper_request = hyper::Request::builder()
74-
.uri(request.uri())
75-
.method(request.method());
76-
77-
for header in request.headers() {
78-
hyper_request = hyper_request.header(header.0, header.1);
124+
let mut reqwest_request = self.request(
125+
request.method().clone(),
126+
url::Url::parse(&request.uri().to_string()).unwrap(),
127+
);
128+
for (header, value) in request.headers() {
129+
reqwest_request = reqwest_request.header(header, value);
79130
}
80131

81-
let hyper_request = hyper_request.body(hyper::Body::from(request.into_body()))?;
132+
let reqwest_request = reqwest_request.body(request.into_body()).build()?;
82133

83-
let hyper_response = self.request(hyper_request).await?;
134+
let reqwest_response = self.execute(reqwest_request).await?;
84135

85-
let mut response = Response::builder()
86-
.status(hyper_response.status())
87-
.version(hyper_response.version());
136+
let mut response = Response::builder().status(reqwest_response.status());
88137

89-
for (key, value) in hyper_response.headers() {
138+
for (key, value) in reqwest_response.headers() {
90139
response = response.header(key, value);
91140
}
92141

93-
let response = response.body(hyper::body::to_bytes(hyper_response.into_body()).await?)?;
142+
let response = response.body(reqwest_response.bytes().await?)?;
94143

95144
Ok(response)
96145
}
97-
}
98146

99-
#[cfg(feature = "enable_reqwest")]
100-
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
101-
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
102-
impl HttpClient for reqwest::Client {
103-
async fn execute_request(
147+
#[cfg(not(target_arch = "wasm32"))]
148+
async fn execute_request2(
104149
&self,
105-
request: Request<Bytes>,
106-
) -> Result<Response<Bytes>, Box<dyn std::error::Error + Sync + Send>> {
107-
let mut reqwest_request =
108-
self.request(request.method().clone(), &request.uri().to_string());
150+
request: &crate::Request,
151+
) -> Result<crate::Response, Box<dyn std::error::Error + Sync + Send>> {
152+
let mut reqwest_request = self.request(
153+
request.method().clone(),
154+
url::Url::parse(&request.uri().to_string()).unwrap(),
155+
);
109156
for header in request.headers() {
110157
reqwest_request = reqwest_request.header(header.0, header.1);
111158
}
112159

113-
let reqwest_request = reqwest_request.body(request.into_body()).build()?;
160+
// We clone the body since we need to give ownership of it to
161+
// Reqwest.
162+
let body = request.clone_body();
114163

115-
let reqwest_response = self.execute(reqwest_request).await?;
164+
let reqwest_request = match body {
165+
Body::Bytes(bytes) => reqwest_request.body(bytes).build()?,
166+
Body::SeekableStream(mut seekable_stream) => {
167+
seekable_stream.reset().await?;
116168

117-
let mut response = Response::builder().status(reqwest_response.status());
169+
reqwest_request
170+
.body(reqwest::Body::wrap_stream(seekable_stream))
171+
.build()?
172+
}
173+
};
118174

119-
if let Some(version) = get_version(&reqwest_response) {
120-
response = response.version(version);
121-
}
175+
let reqwest_response = self.execute(reqwest_request).await?;
176+
let mut response = crate::ResponseBuilder::new(reqwest_response.status());
122177

123178
for (key, value) in reqwest_response.headers() {
124-
response = response.header(key, value);
179+
response.with_header(key, value.clone());
125180
}
126181

127-
let response = response.body(reqwest_response.bytes().await?)?;
182+
let response = response.with_pinned_stream(Box::pin(
183+
reqwest_response.bytes_stream().map_err(|err| err.into()),
184+
));
128185

129186
Ok(response)
130187
}
131-
}
132188

133-
// wasm can not get the http version
134-
#[cfg(feature = "enable_reqwest")]
135-
#[cfg(target_arch = "wasm32")]
136-
fn get_version(_response: &reqwest::Response) -> Option<http::Version> {
137-
None
138-
}
189+
#[cfg(target_arch = "wasm32")]
190+
/// Stub implementation. Will remove as soon as reqwest starts
191+
/// supporting wasm.
192+
async fn execute_request2(
193+
&self,
194+
_request: &crate::Request,
195+
) -> Result<crate::Response, Box<dyn std::error::Error + Sync + Send>> {
196+
let response = crate::ResponseBuilder::new(http::StatusCode::OK);
139197

140-
#[cfg(feature = "enable_reqwest")]
141-
#[cfg(not(target_arch = "wasm32"))]
142-
fn get_version(response: &reqwest::Response) -> Option<http::Version> {
143-
Some(response.version())
198+
let response = response.with_pinned_stream(Box::pin(crate::BytesStream::new_empty()));
199+
200+
Ok(response)
201+
}
144202
}
145203

146204
/// Serialize to json

0 commit comments

Comments
 (0)