Skip to content

Commit 2cc7b24

Browse files
committed
Simplify download API abstraction
1 parent ac84cb5 commit 2cc7b24

File tree

6 files changed

+201
-194
lines changed

6 files changed

+201
-194
lines changed

src/cli/self_update.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ use same_file::Handle;
4848
use serde::{Deserialize, Serialize};
4949
use tracing::{error, info, trace, warn};
5050

51+
use crate::download::download_file;
5152
use crate::{
5253
DUP_TOOLS, TOOLS,
5354
cli::{
@@ -1134,7 +1135,7 @@ pub(crate) async fn prepare_update(process: &Process) -> Result<Option<PathBuf>>
11341135

11351136
// Download new version
11361137
info!("downloading self-update");
1137-
utils::download_file(&download_url, &setup_path, None, &|_| (), process).await?;
1138+
download_file(&download_url, &setup_path, None, &|_| (), process).await?;
11381139

11391140
// Mark as executable
11401141
utils::make_executable(&setup_path)?;
@@ -1153,7 +1154,7 @@ async fn get_available_rustup_version(process: &Process) -> Result<String> {
11531154
let release_file_url = format!("{update_root}/release-stable.toml");
11541155
let release_file_url = utils::parse_url(&release_file_url)?;
11551156
let release_file = tempdir.path().join("release-stable.toml");
1156-
utils::download_file(&release_file_url, &release_file, None, &|_| (), process).await?;
1157+
download_file(&release_file_url, &release_file, None, &|_| (), process).await?;
11571158
let release_toml_str = utils::read_file("rustup release", &release_file)?;
11581159
let release_toml = toml::from_str::<RustupManifest>(&release_toml_str)
11591160
.context("unable to parse rustup release file")?;

src/cli/self_update/windows.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use super::common;
2222
use super::{InstallOpts, install_bins, report_error};
2323
use crate::cli::{download_tracker::DownloadTracker, markdown::md};
2424
use crate::dist::TargetTriple;
25+
use crate::download::download_file;
2526
use crate::process::{Process, terminalsource::ColorableTerminal};
2627
use crate::utils::{self, Notification};
2728

@@ -276,7 +277,7 @@ pub(crate) async fn try_install_msvc(
276277
download_tracker.lock().unwrap().download_finished();
277278

278279
info!("downloading Visual Studio installer");
279-
utils::download_file(
280+
download_file(
280281
&visual_studio_url,
281282
&visual_studio,
282283
None,

src/dist/download.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use url::Url;
88

99
use crate::dist::notifications::*;
1010
use crate::dist::temp;
11+
use crate::download::download_file;
12+
use crate::download::download_file_with_resume;
1113
use crate::errors::*;
1214
use crate::process::Process;
1315
use crate::utils;
@@ -73,7 +75,7 @@ impl<'a> DownloadCfg<'a> {
7375

7476
let mut hasher = Sha256::new();
7577

76-
if let Err(e) = utils::download_file_with_resume(
78+
if let Err(e) = download_file_with_resume(
7779
url,
7880
&partial_file_path,
7981
Some(&mut hasher),
@@ -134,7 +136,7 @@ impl<'a> DownloadCfg<'a> {
134136
let hash_url = utils::parse_url(&(url.to_owned() + ".sha256"))?;
135137
let hash_file = self.tmp_cx.new_file()?;
136138

137-
utils::download_file(
139+
download_file(
138140
&hash_url,
139141
&hash_file,
140142
None,
@@ -179,7 +181,7 @@ impl<'a> DownloadCfg<'a> {
179181
let file = self.tmp_cx.new_file_with_ext("", ext)?;
180182

181183
let mut hasher = Sha256::new();
182-
utils::download_file(
184+
download_file(
183185
&url,
184186
&file,
185187
Some(&mut hasher),

src/dist/manifestation/tests.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use crate::{
2323
prefix::InstallPrefix,
2424
temp,
2525
},
26+
download::download_file,
2627
errors::RustupError,
2728
process::TestProcess,
2829
test::{
@@ -495,7 +496,7 @@ impl TestContext {
495496
// Download the dist manifest and place it into the installation prefix
496497
let manifest_url = make_manifest_url(&self.url, &self.toolchain)?;
497498
let manifest_file = self.tmp_cx.new_file()?;
498-
utils::download_file(&manifest_url, &manifest_file, None, &|_| {}, dl_cfg.process).await?;
499+
download_file(&manifest_url, &manifest_file, None, &|_| {}, dl_cfg.process).await?;
499500
let manifest_str = utils::read_file("manifest", &manifest_file)?;
500501
let manifest = Manifest::parse(&manifest_str)?;
501502

src/download/mod.rs

+189-9
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,192 @@ use std::path::Path;
66
#[cfg(any(not(feature = "curl-backend"), not(feature = "reqwest-rustls-tls"), not(feature = "reqwest-native-tls")))]
77
use anyhow::anyhow;
88
use anyhow::{Context, Result};
9+
use sha2::Sha256;
910
use thiserror::Error;
11+
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
12+
use tracing::info;
1013
use url::Url;
1114

15+
use crate::{errors::RustupError, process::Process, utils::Notification};
16+
1217
#[cfg(test)]
1318
mod tests;
1419

20+
pub(crate) async fn download_file(
21+
url: &Url,
22+
path: &Path,
23+
hasher: Option<&mut Sha256>,
24+
notify_handler: &dyn Fn(Notification<'_>),
25+
process: &Process,
26+
) -> Result<()> {
27+
download_file_with_resume(url, path, hasher, false, &notify_handler, process).await
28+
}
29+
30+
pub(crate) async fn download_file_with_resume(
31+
url: &Url,
32+
path: &Path,
33+
hasher: Option<&mut Sha256>,
34+
resume_from_partial: bool,
35+
notify_handler: &dyn Fn(Notification<'_>),
36+
process: &Process,
37+
) -> Result<()> {
38+
use crate::download::DownloadError as DEK;
39+
match download_file_(
40+
url,
41+
path,
42+
hasher,
43+
resume_from_partial,
44+
notify_handler,
45+
process,
46+
)
47+
.await
48+
{
49+
Ok(_) => Ok(()),
50+
Err(e) => {
51+
if e.downcast_ref::<std::io::Error>().is_some() {
52+
return Err(e);
53+
}
54+
let is_client_error = match e.downcast_ref::<DEK>() {
55+
// Specifically treat the bad partial range error as not our
56+
// fault in case it was something odd which happened.
57+
Some(DEK::HttpStatus(416)) => false,
58+
Some(DEK::HttpStatus(400..=499)) | Some(DEK::FileNotFound) => true,
59+
_ => false,
60+
};
61+
Err(e).with_context(|| {
62+
if is_client_error {
63+
RustupError::DownloadNotExists {
64+
url: url.clone(),
65+
path: path.to_path_buf(),
66+
}
67+
} else {
68+
RustupError::DownloadingFile {
69+
url: url.clone(),
70+
path: path.to_path_buf(),
71+
}
72+
}
73+
})
74+
}
75+
}
76+
}
77+
78+
async fn download_file_(
79+
url: &Url,
80+
path: &Path,
81+
hasher: Option<&mut Sha256>,
82+
resume_from_partial: bool,
83+
notify_handler: &dyn Fn(Notification<'_>),
84+
process: &Process,
85+
) -> Result<()> {
86+
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
87+
use crate::download::{Backend, Event, TlsBackend};
88+
use sha2::Digest;
89+
use std::cell::RefCell;
90+
91+
notify_handler(Notification::DownloadingFile(url, path));
92+
93+
let hasher = RefCell::new(hasher);
94+
95+
// This callback will write the download to disk and optionally
96+
// hash the contents, then forward the notification up the stack
97+
let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| {
98+
if let Event::DownloadDataReceived(data) = msg {
99+
if let Some(h) = hasher.borrow_mut().as_mut() {
100+
h.update(data);
101+
}
102+
}
103+
104+
match msg {
105+
Event::DownloadContentLengthReceived(len) => {
106+
notify_handler(Notification::DownloadContentLengthReceived(len));
107+
}
108+
Event::DownloadDataReceived(data) => {
109+
notify_handler(Notification::DownloadDataReceived(data));
110+
}
111+
Event::ResumingPartialDownload => {
112+
notify_handler(Notification::ResumingPartialDownload);
113+
}
114+
}
115+
116+
Ok(())
117+
};
118+
119+
// Download the file
120+
121+
// Keep the curl env var around for a bit
122+
let use_curl_backend = process.var_os("RUSTUP_USE_CURL").map(|it| it != "0");
123+
let use_rustls = process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0");
124+
125+
let backend = match (use_curl_backend, use_rustls) {
126+
// If environment specifies a backend that's unavailable, error out
127+
#[cfg(not(feature = "reqwest-rustls-tls"))]
128+
(_, Some(true)) => {
129+
return Err(anyhow!(
130+
"RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature"
131+
));
132+
}
133+
#[cfg(not(feature = "reqwest-native-tls"))]
134+
(_, Some(false)) => {
135+
return Err(anyhow!(
136+
"RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature"
137+
));
138+
}
139+
#[cfg(not(feature = "curl-backend"))]
140+
(Some(true), _) => {
141+
return Err(anyhow!(
142+
"RUSTUP_USE_CURL is set, but this rustup distribution was not built with the curl-backend feature"
143+
));
144+
}
145+
146+
// Positive selections, from least preferred to most preferred
147+
#[cfg(feature = "curl-backend")]
148+
(Some(true), None) => Backend::Curl,
149+
#[cfg(feature = "reqwest-native-tls")]
150+
(_, Some(false)) => {
151+
if use_curl_backend == Some(true) {
152+
info!(
153+
"RUSTUP_USE_CURL is set and RUSTUP_USE_RUSTLS is set to off, using reqwest with native-tls"
154+
);
155+
}
156+
Backend::Reqwest(TlsBackend::NativeTls)
157+
}
158+
#[cfg(feature = "reqwest-rustls-tls")]
159+
_ => {
160+
if use_curl_backend == Some(true) {
161+
info!(
162+
"both RUSTUP_USE_CURL and RUSTUP_USE_RUSTLS are set, using reqwest with rustls"
163+
);
164+
}
165+
Backend::Reqwest(TlsBackend::Rustls)
166+
}
167+
168+
// Falling back if only one backend is available
169+
#[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))]
170+
_ => Backend::Reqwest(TlsBackend::NativeTls),
171+
#[cfg(all(
172+
not(feature = "reqwest-rustls-tls"),
173+
not(feature = "reqwest-native-tls"),
174+
feature = "curl-backend"
175+
))]
176+
_ => Backend::Curl,
177+
};
178+
179+
notify_handler(match backend {
180+
#[cfg(feature = "curl-backend")]
181+
Backend::Curl => Notification::UsingCurl,
182+
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
183+
Backend::Reqwest(_) => Notification::UsingReqwest,
184+
});
185+
186+
let res = backend
187+
.download_to_path(url, path, resume_from_partial, Some(callback))
188+
.await;
189+
190+
notify_handler(Notification::DownloadFinished);
191+
192+
res
193+
}
194+
15195
/// User agent header value for HTTP request.
16196
/// See: https://github.com/rust-lang/rustup/issues/2860.
17197
#[cfg(feature = "curl-backend")]
@@ -29,15 +209,15 @@ const REQWEST_RUSTLS_TLS_USER_AGENT: &str =
29209
concat!("rustup/", env!("CARGO_PKG_VERSION"), " (reqwest; rustls)");
30210

31211
#[derive(Debug, Copy, Clone)]
32-
pub enum Backend {
212+
enum Backend {
33213
#[cfg(feature = "curl-backend")]
34214
Curl,
35215
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
36216
Reqwest(TlsBackend),
37217
}
38218

39219
impl Backend {
40-
pub async fn download_to_path(
220+
async fn download_to_path(
41221
self,
42222
url: &Url,
43223
path: &Path,
@@ -171,7 +351,7 @@ impl Backend {
171351

172352
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
173353
#[derive(Debug, Copy, Clone)]
174-
pub enum TlsBackend {
354+
enum TlsBackend {
175355
#[cfg(feature = "reqwest-rustls-tls")]
176356
Rustls,
177357
#[cfg(feature = "reqwest-native-tls")]
@@ -198,7 +378,7 @@ impl TlsBackend {
198378
}
199379

200380
#[derive(Debug, Copy, Clone)]
201-
pub enum Event<'a> {
381+
enum Event<'a> {
202382
ResumingPartialDownload,
203383
/// Received the Content-Length of the to-be downloaded data.
204384
DownloadContentLengthReceived(u64),
@@ -211,7 +391,7 @@ type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> Result<()>;
211391
/// Download via libcurl; encrypt with the native (or OpenSSl) TLS
212392
/// stack via libcurl
213393
#[cfg(feature = "curl-backend")]
214-
pub mod curl {
394+
mod curl {
215395
use std::cell::RefCell;
216396
use std::str;
217397
use std::time::Duration;
@@ -222,7 +402,7 @@ pub mod curl {
222402

223403
use super::{DownloadError, Event};
224404

225-
pub fn download(
405+
pub(super) fn download(
226406
url: &Url,
227407
resume_from: u64,
228408
callback: &dyn Fn(Event<'_>) -> Result<()>,
@@ -323,7 +503,7 @@ pub mod curl {
323503
}
324504

325505
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
326-
pub mod reqwest_be {
506+
mod reqwest_be {
327507
use std::io;
328508
#[cfg(feature = "reqwest-rustls-tls")]
329509
use std::sync::Arc;
@@ -342,7 +522,7 @@ pub mod reqwest_be {
342522

343523
use super::{DownloadError, Event};
344524

345-
pub async fn download(
525+
pub(super) async fn download(
346526
url: &Url,
347527
resume_from: u64,
348528
callback: &dyn Fn(Event<'_>) -> Result<()>,
@@ -487,7 +667,7 @@ pub mod reqwest_be {
487667
}
488668

489669
#[derive(Debug, Error)]
490-
pub enum DownloadError {
670+
enum DownloadError {
491671
#[error("http request returned an unsuccessful status code: {0}")]
492672
HttpStatus(u32),
493673
#[error("file not found")]

0 commit comments

Comments
 (0)