Skip to content

Commit 77efaf1

Browse files
djcrami3l
authored andcommitted
download: merge integration test files
1 parent ca72ada commit 77efaf1

File tree

5 files changed

+365
-371
lines changed

5 files changed

+365
-371
lines changed

download/tests/all.rs

+365
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
use std::convert::Infallible;
2+
use std::fs;
3+
use std::io;
4+
use std::net::SocketAddr;
5+
use std::path::Path;
6+
use std::sync::mpsc::{Sender, channel};
7+
use std::thread;
8+
9+
use http_body_util::Full;
10+
use hyper::Request;
11+
use hyper::body::Bytes;
12+
use hyper::server::conn::http1;
13+
use hyper::service::service_fn;
14+
use tempfile::TempDir;
15+
16+
#[cfg(feature = "curl-backend")]
17+
mod curl {
18+
use std::sync::Mutex;
19+
use std::sync::atomic::{AtomicBool, Ordering};
20+
21+
use url::Url;
22+
23+
use super::{serve_file, tmp_dir, write_file};
24+
use download::*;
25+
26+
#[tokio::test]
27+
async fn partially_downloaded_file_gets_resumed_from_byte_offset() {
28+
let tmpdir = tmp_dir();
29+
let from_path = tmpdir.path().join("download-source");
30+
write_file(&from_path, "xxx45");
31+
32+
let target_path = tmpdir.path().join("downloaded");
33+
write_file(&target_path, "123");
34+
35+
let from_url = Url::from_file_path(&from_path).unwrap();
36+
Backend::Curl
37+
.download_to_path(&from_url, &target_path, true, None)
38+
.await
39+
.expect("Test download failed");
40+
41+
assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345");
42+
}
43+
44+
#[tokio::test]
45+
async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() {
46+
let tmpdir = tmp_dir();
47+
let target_path = tmpdir.path().join("downloaded");
48+
write_file(&target_path, "123");
49+
50+
let addr = serve_file(b"xxx45".to_vec());
51+
52+
let from_url = format!("http://{addr}").parse().unwrap();
53+
54+
let callback_partial = AtomicBool::new(false);
55+
let callback_len = Mutex::new(None);
56+
let received_in_callback = Mutex::new(Vec::new());
57+
58+
Backend::Curl
59+
.download_to_path(
60+
&from_url,
61+
&target_path,
62+
true,
63+
Some(&|msg| {
64+
match msg {
65+
Event::ResumingPartialDownload => {
66+
assert!(!callback_partial.load(Ordering::SeqCst));
67+
callback_partial.store(true, Ordering::SeqCst);
68+
}
69+
Event::DownloadContentLengthReceived(len) => {
70+
let mut flag = callback_len.lock().unwrap();
71+
assert!(flag.is_none());
72+
*flag = Some(len);
73+
}
74+
Event::DownloadDataReceived(data) => {
75+
for b in data.iter() {
76+
received_in_callback.lock().unwrap().push(*b);
77+
}
78+
}
79+
}
80+
81+
Ok(())
82+
}),
83+
)
84+
.await
85+
.expect("Test download failed");
86+
87+
assert!(callback_partial.into_inner());
88+
assert_eq!(*callback_len.lock().unwrap(), Some(5));
89+
let observed_bytes = received_in_callback.into_inner().unwrap();
90+
assert_eq!(observed_bytes, vec![b'1', b'2', b'3', b'4', b'5']);
91+
assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345");
92+
}
93+
}
94+
95+
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
96+
mod reqwest {
97+
use std::env::{remove_var, set_var};
98+
use std::error::Error;
99+
use std::net::TcpListener;
100+
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
101+
use std::sync::{LazyLock, Mutex};
102+
use std::thread;
103+
use std::time::Duration;
104+
105+
use env_proxy::for_url;
106+
use reqwest::{Client, Proxy};
107+
use url::Url;
108+
109+
use super::{serve_file, tmp_dir, write_file};
110+
use download::{Backend, Event, TlsBackend};
111+
112+
static SERIALISE_TESTS: LazyLock<tokio::sync::Mutex<()>> =
113+
LazyLock::new(|| tokio::sync::Mutex::new(()));
114+
115+
unsafe fn scrub_env() {
116+
unsafe {
117+
remove_var("http_proxy");
118+
remove_var("https_proxy");
119+
remove_var("HTTPS_PROXY");
120+
remove_var("ftp_proxy");
121+
remove_var("FTP_PROXY");
122+
remove_var("all_proxy");
123+
remove_var("ALL_PROXY");
124+
remove_var("no_proxy");
125+
remove_var("NO_PROXY");
126+
}
127+
}
128+
129+
// Tests for correctly retrieving the proxy (host, port) tuple from $https_proxy
130+
#[tokio::test]
131+
async fn read_basic_proxy_params() {
132+
let _guard = SERIALISE_TESTS.lock().await;
133+
// SAFETY: We are setting environment variables when `SERIALISE_TESTS` is locked,
134+
// and those environment variables in question are not relevant elsewhere in the test suite.
135+
unsafe {
136+
scrub_env();
137+
set_var("https_proxy", "http://proxy.example.com:8080");
138+
}
139+
let u = Url::parse("https://www.example.org").ok().unwrap();
140+
assert_eq!(
141+
for_url(&u).host_port(),
142+
Some(("proxy.example.com".to_string(), 8080))
143+
);
144+
}
145+
146+
// Tests to verify if socks feature is available and being used
147+
#[tokio::test]
148+
async fn socks_proxy_request() {
149+
static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
150+
let _guard = SERIALISE_TESTS.lock().await;
151+
152+
// SAFETY: We are setting environment variables when `SERIALISE_TESTS` is locked,
153+
// and those environment variables in question are not relevant elsewhere in the test suite.
154+
unsafe {
155+
scrub_env();
156+
set_var("all_proxy", "socks5://127.0.0.1:1080");
157+
}
158+
159+
thread::spawn(move || {
160+
let listener = TcpListener::bind("127.0.0.1:1080").unwrap();
161+
let incoming = listener.incoming();
162+
for _ in incoming {
163+
CALL_COUNT.fetch_add(1, Ordering::SeqCst);
164+
}
165+
});
166+
167+
let env_proxy = |url: &Url| for_url(url).to_url();
168+
let url = Url::parse("http://192.168.0.1/").unwrap();
169+
170+
let client = Client::builder()
171+
// HACK: set `pool_max_idle_per_host` to `0` to avoid an issue in the underlying
172+
// `hyper` library that causes the `reqwest` client to hang in some cases.
173+
// See <https://github.com/hyperium/hyper/issues/2312> for more details.
174+
.pool_max_idle_per_host(0)
175+
.proxy(Proxy::custom(env_proxy))
176+
.timeout(Duration::from_secs(1))
177+
.build()
178+
.unwrap();
179+
let res = client.get(url.as_str()).send().await;
180+
181+
if let Err(e) = res {
182+
let s = e.source().unwrap();
183+
assert!(
184+
s.to_string().contains("client error (Connect)"),
185+
"Expected socks connect error, got: {s}",
186+
);
187+
assert_eq!(CALL_COUNT.load(Ordering::SeqCst), 1);
188+
} else {
189+
panic!("Socks proxy was ignored")
190+
}
191+
}
192+
193+
#[tokio::test]
194+
async fn resume_partial_from_file_url() {
195+
let tmpdir = tmp_dir();
196+
let from_path = tmpdir.path().join("download-source");
197+
write_file(&from_path, "xxx45");
198+
199+
let target_path = tmpdir.path().join("downloaded");
200+
write_file(&target_path, "123");
201+
202+
let from_url = Url::from_file_path(&from_path).unwrap();
203+
Backend::Reqwest(TlsBackend::NativeTls)
204+
.download_to_path(&from_url, &target_path, true, None)
205+
.await
206+
.expect("Test download failed");
207+
208+
assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345");
209+
}
210+
211+
#[tokio::test]
212+
async fn callback_gets_all_data_as_if_the_download_happened_all_at_once() {
213+
let tmpdir = tmp_dir();
214+
let target_path = tmpdir.path().join("downloaded");
215+
write_file(&target_path, "123");
216+
217+
let addr = serve_file(b"xxx45".to_vec());
218+
219+
let from_url = format!("http://{addr}").parse().unwrap();
220+
221+
let callback_partial = AtomicBool::new(false);
222+
let callback_len = Mutex::new(None);
223+
let received_in_callback = Mutex::new(Vec::new());
224+
225+
Backend::Reqwest(TlsBackend::NativeTls)
226+
.download_to_path(
227+
&from_url,
228+
&target_path,
229+
true,
230+
Some(&|msg| {
231+
match msg {
232+
Event::ResumingPartialDownload => {
233+
assert!(!callback_partial.load(Ordering::SeqCst));
234+
callback_partial.store(true, Ordering::SeqCst);
235+
}
236+
Event::DownloadContentLengthReceived(len) => {
237+
let mut flag = callback_len.lock().unwrap();
238+
assert!(flag.is_none());
239+
*flag = Some(len);
240+
}
241+
Event::DownloadDataReceived(data) => {
242+
for b in data.iter() {
243+
received_in_callback.lock().unwrap().push(*b);
244+
}
245+
}
246+
}
247+
248+
Ok(())
249+
}),
250+
)
251+
.await
252+
.expect("Test download failed");
253+
254+
assert!(callback_partial.into_inner());
255+
assert_eq!(*callback_len.lock().unwrap(), Some(5));
256+
let observed_bytes = received_in_callback.into_inner().unwrap();
257+
assert_eq!(observed_bytes, vec![b'1', b'2', b'3', b'4', b'5']);
258+
assert_eq!(std::fs::read_to_string(&target_path).unwrap(), "12345");
259+
}
260+
}
261+
262+
pub fn tmp_dir() -> TempDir {
263+
tempfile::Builder::new()
264+
.prefix("rustup-download-test-")
265+
.tempdir()
266+
.expect("creating tempdir for test")
267+
}
268+
269+
pub fn write_file(path: &Path, contents: &str) {
270+
let mut file = fs::OpenOptions::new()
271+
.write(true)
272+
.truncate(true)
273+
.create(true)
274+
.open(path)
275+
.expect("writing test data");
276+
277+
io::Write::write_all(&mut file, contents.as_bytes()).expect("writing test data");
278+
279+
file.sync_data().expect("writing test data");
280+
}
281+
282+
// A dead simple hyper server implementation.
283+
// For more info, see:
284+
// https://hyper.rs/guides/1/server/hello-world/
285+
async fn run_server(addr_tx: Sender<SocketAddr>, addr: SocketAddr, contents: Vec<u8>) {
286+
let svc = service_fn(move |req: Request<hyper::body::Incoming>| {
287+
let contents = contents.clone();
288+
async move {
289+
let res = serve_contents(req, contents);
290+
Ok::<_, Infallible>(res)
291+
}
292+
});
293+
294+
let listener = tokio::net::TcpListener::bind(&addr)
295+
.await
296+
.expect("can not bind");
297+
298+
let addr = listener.local_addr().unwrap();
299+
addr_tx.send(addr).unwrap();
300+
301+
loop {
302+
let (stream, _) = listener
303+
.accept()
304+
.await
305+
.expect("could not accept connection");
306+
let io = hyper_util::rt::TokioIo::new(stream);
307+
308+
let svc = svc.clone();
309+
tokio::spawn(async move {
310+
if let Err(err) = http1::Builder::new().serve_connection(io, svc).await {
311+
eprintln!("failed to serve connection: {:?}", err);
312+
}
313+
});
314+
}
315+
}
316+
317+
pub fn serve_file(contents: Vec<u8>) -> SocketAddr {
318+
let addr = ([127, 0, 0, 1], 0).into();
319+
let (addr_tx, addr_rx) = channel();
320+
321+
thread::spawn(move || {
322+
let server = run_server(addr_tx, addr, contents);
323+
let rt = tokio::runtime::Runtime::new().expect("could not creating Runtime");
324+
rt.block_on(server);
325+
});
326+
327+
let addr = addr_rx.recv();
328+
addr.unwrap()
329+
}
330+
331+
fn serve_contents(
332+
req: hyper::Request<hyper::body::Incoming>,
333+
contents: Vec<u8>,
334+
) -> hyper::Response<Full<Bytes>> {
335+
let mut range_header = None;
336+
let (status, body) = if let Some(range) = req.headers().get(hyper::header::RANGE) {
337+
// extract range "bytes={start}-"
338+
let range = range.to_str().expect("unexpected Range header");
339+
assert!(range.starts_with("bytes="));
340+
let range = range.trim_start_matches("bytes=");
341+
assert!(range.ends_with('-'));
342+
let range = range.trim_end_matches('-');
343+
assert_eq!(range.split('-').count(), 1);
344+
let start: u64 = range.parse().expect("unexpected Range header");
345+
346+
range_header = Some(format!("bytes {}-{len}/{len}", start, len = contents.len()));
347+
(
348+
hyper::StatusCode::PARTIAL_CONTENT,
349+
contents[start as usize..].to_vec(),
350+
)
351+
} else {
352+
(hyper::StatusCode::OK, contents)
353+
};
354+
355+
let mut res = hyper::Response::builder()
356+
.status(status)
357+
.header(hyper::header::CONTENT_LENGTH, body.len())
358+
.body(Full::new(Bytes::from(body)))
359+
.unwrap();
360+
if let Some(range) = range_header {
361+
res.headers_mut()
362+
.insert(hyper::header::CONTENT_RANGE, range.parse().unwrap());
363+
}
364+
res
365+
}

0 commit comments

Comments
 (0)