Skip to content

Commit 6b0dde3

Browse files
committed
Simplify download API abstraction
1 parent 4b4d0bd commit 6b0dde3

File tree

6 files changed

+201
-194
lines changed

6 files changed

+201
-194
lines changed

src/cli/self_update.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ use same_file::Handle;
4848
use serde::{Deserialize, Serialize};
4949
use tracing::{error, info, trace, warn};
5050

51+
use crate::download::download_file;
5152
use crate::{
5253
DUP_TOOLS, TOOLS,
5354
cli::{
@@ -1134,7 +1135,7 @@ pub(crate) async fn prepare_update(process: &Process) -> Result<Option<PathBuf>>
11341135

11351136
// Download new version
11361137
info!("downloading self-update");
1137-
utils::download_file(&download_url, &setup_path, None, &|_| (), process).await?;
1138+
download_file(&download_url, &setup_path, None, &|_| (), process).await?;
11381139

11391140
// Mark as executable
11401141
utils::make_executable(&setup_path)?;
@@ -1153,7 +1154,7 @@ async fn get_available_rustup_version(process: &Process) -> Result<String> {
11531154
let release_file_url = format!("{update_root}/release-stable.toml");
11541155
let release_file_url = utils::parse_url(&release_file_url)?;
11551156
let release_file = tempdir.path().join("release-stable.toml");
1156-
utils::download_file(&release_file_url, &release_file, None, &|_| (), process).await?;
1157+
download_file(&release_file_url, &release_file, None, &|_| (), process).await?;
11571158
let release_toml_str = utils::read_file("rustup release", &release_file)?;
11581159
let release_toml = toml::from_str::<RustupManifest>(&release_toml_str)
11591160
.context("unable to parse rustup release file")?;

src/cli/self_update/windows.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use super::common;
2222
use super::{InstallOpts, install_bins, report_error};
2323
use crate::cli::{download_tracker::DownloadTracker, markdown::md};
2424
use crate::dist::TargetTriple;
25+
use crate::download::download_file;
2526
use crate::process::{Process, terminalsource::ColorableTerminal};
2627
use crate::utils::{self, Notification};
2728

@@ -276,7 +277,7 @@ pub(crate) async fn try_install_msvc(
276277
download_tracker.lock().unwrap().download_finished();
277278

278279
info!("downloading Visual Studio installer");
279-
utils::download_file(
280+
download_file(
280281
&visual_studio_url,
281282
&visual_studio,
282283
None,

src/dist/download.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use url::Url;
88

99
use crate::dist::notifications::*;
1010
use crate::dist::temp;
11+
use crate::download::download_file;
12+
use crate::download::download_file_with_resume;
1113
use crate::errors::*;
1214
use crate::process::Process;
1315
use crate::utils;
@@ -73,7 +75,7 @@ impl<'a> DownloadCfg<'a> {
7375

7476
let mut hasher = Sha256::new();
7577

76-
if let Err(e) = utils::download_file_with_resume(
78+
if let Err(e) = download_file_with_resume(
7779
url,
7880
&partial_file_path,
7981
Some(&mut hasher),
@@ -134,7 +136,7 @@ impl<'a> DownloadCfg<'a> {
134136
let hash_url = utils::parse_url(&(url.to_owned() + ".sha256"))?;
135137
let hash_file = self.tmp_cx.new_file()?;
136138

137-
utils::download_file(
139+
download_file(
138140
&hash_url,
139141
&hash_file,
140142
None,
@@ -179,7 +181,7 @@ impl<'a> DownloadCfg<'a> {
179181
let file = self.tmp_cx.new_file_with_ext("", ext)?;
180182

181183
let mut hasher = Sha256::new();
182-
utils::download_file(
184+
download_file(
183185
&url,
184186
&file,
185187
Some(&mut hasher),

src/dist/manifestation/tests.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use crate::{
2323
prefix::InstallPrefix,
2424
temp,
2525
},
26+
download::download_file,
2627
errors::RustupError,
2728
process::TestProcess,
2829
test::{
@@ -495,7 +496,7 @@ impl TestContext {
495496
// Download the dist manifest and place it into the installation prefix
496497
let manifest_url = make_manifest_url(&self.url, &self.toolchain)?;
497498
let manifest_file = self.tmp_cx.new_file()?;
498-
utils::download_file(&manifest_url, &manifest_file, None, &|_| {}, dl_cfg.process).await?;
499+
download_file(&manifest_url, &manifest_file, None, &|_| {}, dl_cfg.process).await?;
499500
let manifest_str = utils::read_file("manifest", &manifest_file)?;
500501
let manifest = Manifest::parse(&manifest_str)?;
501502

src/download/mod.rs

+189-9
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,192 @@ use std::fs::remove_file;
44
use std::path::Path;
55

66
use anyhow::{Context, Result};
7+
use sha2::Sha256;
78
use thiserror::Error;
9+
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
10+
use tracing::info;
811
use url::Url;
912

13+
use crate::{errors::RustupError, process::Process, utils::Notification};
14+
1015
#[cfg(test)]
1116
mod tests;
1217

18+
pub(crate) async fn download_file(
19+
url: &Url,
20+
path: &Path,
21+
hasher: Option<&mut Sha256>,
22+
notify_handler: &dyn Fn(Notification<'_>),
23+
process: &Process,
24+
) -> Result<()> {
25+
download_file_with_resume(url, path, hasher, false, &notify_handler, process).await
26+
}
27+
28+
pub(crate) async fn download_file_with_resume(
29+
url: &Url,
30+
path: &Path,
31+
hasher: Option<&mut Sha256>,
32+
resume_from_partial: bool,
33+
notify_handler: &dyn Fn(Notification<'_>),
34+
process: &Process,
35+
) -> Result<()> {
36+
use crate::download::DownloadError as DEK;
37+
match download_file_(
38+
url,
39+
path,
40+
hasher,
41+
resume_from_partial,
42+
notify_handler,
43+
process,
44+
)
45+
.await
46+
{
47+
Ok(_) => Ok(()),
48+
Err(e) => {
49+
if e.downcast_ref::<std::io::Error>().is_some() {
50+
return Err(e);
51+
}
52+
let is_client_error = match e.downcast_ref::<DEK>() {
53+
// Specifically treat the bad partial range error as not our
54+
// fault in case it was something odd which happened.
55+
Some(DEK::HttpStatus(416)) => false,
56+
Some(DEK::HttpStatus(400..=499)) | Some(DEK::FileNotFound) => true,
57+
_ => false,
58+
};
59+
Err(e).with_context(|| {
60+
if is_client_error {
61+
RustupError::DownloadNotExists {
62+
url: url.clone(),
63+
path: path.to_path_buf(),
64+
}
65+
} else {
66+
RustupError::DownloadingFile {
67+
url: url.clone(),
68+
path: path.to_path_buf(),
69+
}
70+
}
71+
})
72+
}
73+
}
74+
}
75+
76+
async fn download_file_(
77+
url: &Url,
78+
path: &Path,
79+
hasher: Option<&mut Sha256>,
80+
resume_from_partial: bool,
81+
notify_handler: &dyn Fn(Notification<'_>),
82+
process: &Process,
83+
) -> Result<()> {
84+
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
85+
use crate::download::{Backend, Event, TlsBackend};
86+
use sha2::Digest;
87+
use std::cell::RefCell;
88+
89+
notify_handler(Notification::DownloadingFile(url, path));
90+
91+
let hasher = RefCell::new(hasher);
92+
93+
// This callback will write the download to disk and optionally
94+
// hash the contents, then forward the notification up the stack
95+
let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| {
96+
if let Event::DownloadDataReceived(data) = msg {
97+
if let Some(h) = hasher.borrow_mut().as_mut() {
98+
h.update(data);
99+
}
100+
}
101+
102+
match msg {
103+
Event::DownloadContentLengthReceived(len) => {
104+
notify_handler(Notification::DownloadContentLengthReceived(len));
105+
}
106+
Event::DownloadDataReceived(data) => {
107+
notify_handler(Notification::DownloadDataReceived(data));
108+
}
109+
Event::ResumingPartialDownload => {
110+
notify_handler(Notification::ResumingPartialDownload);
111+
}
112+
}
113+
114+
Ok(())
115+
};
116+
117+
// Download the file
118+
119+
// Keep the curl env var around for a bit
120+
let use_curl_backend = process.var_os("RUSTUP_USE_CURL").map(|it| it != "0");
121+
let use_rustls = process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0");
122+
123+
let backend = match (use_curl_backend, use_rustls) {
124+
// If environment specifies a backend that's unavailable, error out
125+
#[cfg(not(feature = "reqwest-rustls-tls"))]
126+
(_, Some(true)) => {
127+
return Err(anyhow!(
128+
"RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature"
129+
));
130+
}
131+
#[cfg(not(feature = "reqwest-native-tls"))]
132+
(_, Some(false)) => {
133+
return Err(anyhow!(
134+
"RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature"
135+
));
136+
}
137+
#[cfg(not(feature = "curl-backend"))]
138+
(Some(true), _) => {
139+
return Err(anyhow!(
140+
"RUSTUP_USE_CURL is set, but this rustup distribution was not built with the curl-backend feature"
141+
));
142+
}
143+
144+
// Positive selections, from least preferred to most preferred
145+
#[cfg(feature = "curl-backend")]
146+
(Some(true), None) => Backend::Curl,
147+
#[cfg(feature = "reqwest-native-tls")]
148+
(_, Some(false)) => {
149+
if use_curl_backend == Some(true) {
150+
info!(
151+
"RUSTUP_USE_CURL is set and RUSTUP_USE_RUSTLS is set to off, using reqwest with native-tls"
152+
);
153+
}
154+
Backend::Reqwest(TlsBackend::NativeTls)
155+
}
156+
#[cfg(feature = "reqwest-rustls-tls")]
157+
_ => {
158+
if use_curl_backend == Some(true) {
159+
info!(
160+
"both RUSTUP_USE_CURL and RUSTUP_USE_RUSTLS are set, using reqwest with rustls"
161+
);
162+
}
163+
Backend::Reqwest(TlsBackend::Rustls)
164+
}
165+
166+
// Falling back if only one backend is available
167+
#[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))]
168+
_ => Backend::Reqwest(TlsBackend::NativeTls),
169+
#[cfg(all(
170+
not(feature = "reqwest-rustls-tls"),
171+
not(feature = "reqwest-native-tls"),
172+
feature = "curl-backend"
173+
))]
174+
_ => Backend::Curl,
175+
};
176+
177+
notify_handler(match backend {
178+
#[cfg(feature = "curl-backend")]
179+
Backend::Curl => Notification::UsingCurl,
180+
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
181+
Backend::Reqwest(_) => Notification::UsingReqwest,
182+
});
183+
184+
let res = backend
185+
.download_to_path(url, path, resume_from_partial, Some(callback))
186+
.await;
187+
188+
notify_handler(Notification::DownloadFinished);
189+
190+
res
191+
}
192+
13193
/// User agent header value for HTTP request.
14194
/// See: https://github.com/rust-lang/rustup/issues/2860.
15195
#[cfg(feature = "curl-backend")]
@@ -27,15 +207,15 @@ const REQWEST_RUSTLS_TLS_USER_AGENT: &str =
27207
concat!("rustup/", env!("CARGO_PKG_VERSION"), " (reqwest; rustls)");
28208

29209
#[derive(Debug, Copy, Clone)]
30-
pub enum Backend {
210+
enum Backend {
31211
#[cfg(feature = "curl-backend")]
32212
Curl,
33213
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
34214
Reqwest(TlsBackend),
35215
}
36216

37217
impl Backend {
38-
pub async fn download_to_path(
218+
async fn download_to_path(
39219
self,
40220
url: &Url,
41221
path: &Path,
@@ -169,7 +349,7 @@ impl Backend {
169349

170350
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
171351
#[derive(Debug, Copy, Clone)]
172-
pub enum TlsBackend {
352+
enum TlsBackend {
173353
#[cfg(feature = "reqwest-rustls-tls")]
174354
Rustls,
175355
#[cfg(feature = "reqwest-native-tls")]
@@ -196,7 +376,7 @@ impl TlsBackend {
196376
}
197377

198378
#[derive(Debug, Copy, Clone)]
199-
pub enum Event<'a> {
379+
enum Event<'a> {
200380
ResumingPartialDownload,
201381
/// Received the Content-Length of the to-be downloaded data.
202382
DownloadContentLengthReceived(u64),
@@ -209,7 +389,7 @@ type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> Result<()>;
209389
/// Download via libcurl; encrypt with the native (or OpenSSl) TLS
210390
/// stack via libcurl
211391
#[cfg(feature = "curl-backend")]
212-
pub mod curl {
392+
mod curl {
213393
use std::cell::RefCell;
214394
use std::str;
215395
use std::time::Duration;
@@ -220,7 +400,7 @@ pub mod curl {
220400

221401
use super::{DownloadError, Event};
222402

223-
pub fn download(
403+
pub(super) fn download(
224404
url: &Url,
225405
resume_from: u64,
226406
callback: &dyn Fn(Event<'_>) -> Result<()>,
@@ -321,7 +501,7 @@ pub mod curl {
321501
}
322502

323503
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
324-
pub mod reqwest_be {
504+
mod reqwest_be {
325505
use std::io;
326506
#[cfg(feature = "reqwest-rustls-tls")]
327507
use std::sync::Arc;
@@ -340,7 +520,7 @@ pub mod reqwest_be {
340520

341521
use super::{DownloadError, Event};
342522

343-
pub async fn download(
523+
pub(super) async fn download(
344524
url: &Url,
345525
resume_from: u64,
346526
callback: &dyn Fn(Event<'_>) -> Result<()>,
@@ -485,7 +665,7 @@ pub mod reqwest_be {
485665
}
486666

487667
#[derive(Debug, Error)]
488-
pub enum DownloadError {
668+
enum DownloadError {
489669
#[error("http request returned an unsuccessful status code: {0}")]
490670
HttpStatus(u32),
491671
#[error("file not found")]

0 commit comments

Comments
 (0)