Skip to content

Commit 0788cd6

Browse files
djcrami3l
authored andcommitted
Simplify download API abstraction
1 parent a021acd commit 0788cd6

File tree

6 files changed

+201
-194
lines changed

6 files changed

+201
-194
lines changed

src/cli/self_update.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ use same_file::Handle;
4848
use serde::{Deserialize, Serialize};
4949
use tracing::{error, info, trace, warn};
5050

51+
use crate::download::download_file;
5152
use crate::{
5253
DUP_TOOLS, TOOLS,
5354
cli::{
@@ -1134,7 +1135,7 @@ pub(crate) async fn prepare_update(process: &Process) -> Result<Option<PathBuf>>
11341135

11351136
// Download new version
11361137
info!("downloading self-update");
1137-
utils::download_file(&download_url, &setup_path, None, &|_| (), process).await?;
1138+
download_file(&download_url, &setup_path, None, &|_| (), process).await?;
11381139

11391140
// Mark as executable
11401141
utils::make_executable(&setup_path)?;
@@ -1153,7 +1154,7 @@ async fn get_available_rustup_version(process: &Process) -> Result<String> {
11531154
let release_file_url = format!("{update_root}/release-stable.toml");
11541155
let release_file_url = utils::parse_url(&release_file_url)?;
11551156
let release_file = tempdir.path().join("release-stable.toml");
1156-
utils::download_file(&release_file_url, &release_file, None, &|_| (), process).await?;
1157+
download_file(&release_file_url, &release_file, None, &|_| (), process).await?;
11571158
let release_toml_str = utils::read_file("rustup release", &release_file)?;
11581159
let release_toml = toml::from_str::<RustupManifest>(&release_toml_str)
11591160
.context("unable to parse rustup release file")?;

src/cli/self_update/windows.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use super::common;
2222
use super::{InstallOpts, install_bins, report_error};
2323
use crate::cli::{download_tracker::DownloadTracker, markdown::md};
2424
use crate::dist::TargetTriple;
25+
use crate::download::download_file;
2526
use crate::process::{Process, terminalsource::ColorableTerminal};
2627
use crate::utils::{self, Notification};
2728

@@ -276,7 +277,7 @@ pub(crate) async fn try_install_msvc(
276277
download_tracker.lock().unwrap().download_finished();
277278

278279
info!("downloading Visual Studio installer");
279-
utils::download_file(
280+
download_file(
280281
&visual_studio_url,
281282
&visual_studio,
282283
None,

src/dist/download.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use url::Url;
88

99
use crate::dist::notifications::*;
1010
use crate::dist::temp;
11+
use crate::download::download_file;
12+
use crate::download::download_file_with_resume;
1113
use crate::errors::*;
1214
use crate::process::Process;
1315
use crate::utils;
@@ -73,7 +75,7 @@ impl<'a> DownloadCfg<'a> {
7375

7476
let mut hasher = Sha256::new();
7577

76-
if let Err(e) = utils::download_file_with_resume(
78+
if let Err(e) = download_file_with_resume(
7779
url,
7880
&partial_file_path,
7981
Some(&mut hasher),
@@ -134,7 +136,7 @@ impl<'a> DownloadCfg<'a> {
134136
let hash_url = utils::parse_url(&(url.to_owned() + ".sha256"))?;
135137
let hash_file = self.tmp_cx.new_file()?;
136138

137-
utils::download_file(
139+
download_file(
138140
&hash_url,
139141
&hash_file,
140142
None,
@@ -179,7 +181,7 @@ impl<'a> DownloadCfg<'a> {
179181
let file = self.tmp_cx.new_file_with_ext("", ext)?;
180182

181183
let mut hasher = Sha256::new();
182-
utils::download_file(
184+
download_file(
183185
&url,
184186
&file,
185187
Some(&mut hasher),

src/dist/manifestation/tests.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use crate::{
2323
prefix::InstallPrefix,
2424
temp,
2525
},
26+
download::download_file,
2627
errors::RustupError,
2728
process::TestProcess,
2829
test::{
@@ -495,7 +496,7 @@ impl TestContext {
495496
// Download the dist manifest and place it into the installation prefix
496497
let manifest_url = make_manifest_url(&self.url, &self.toolchain)?;
497498
let manifest_file = self.tmp_cx.new_file()?;
498-
utils::download_file(&manifest_url, &manifest_file, None, &|_| {}, dl_cfg.process).await?;
499+
download_file(&manifest_url, &manifest_file, None, &|_| {}, dl_cfg.process).await?;
499500
let manifest_str = utils::read_file("manifest", &manifest_file)?;
500501
let manifest = Manifest::parse(&manifest_str)?;
501502

src/download/mod.rs

+189-9
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,192 @@ use std::path::Path;
1010
))]
1111
use anyhow::anyhow;
1212
use anyhow::{Context, Result};
13+
use sha2::Sha256;
1314
use thiserror::Error;
15+
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
16+
use tracing::info;
1417
use url::Url;
1518

19+
use crate::{errors::RustupError, process::Process, utils::Notification};
20+
1621
#[cfg(test)]
1722
mod tests;
1823

24+
pub(crate) async fn download_file(
25+
url: &Url,
26+
path: &Path,
27+
hasher: Option<&mut Sha256>,
28+
notify_handler: &dyn Fn(Notification<'_>),
29+
process: &Process,
30+
) -> Result<()> {
31+
download_file_with_resume(url, path, hasher, false, &notify_handler, process).await
32+
}
33+
34+
pub(crate) async fn download_file_with_resume(
35+
url: &Url,
36+
path: &Path,
37+
hasher: Option<&mut Sha256>,
38+
resume_from_partial: bool,
39+
notify_handler: &dyn Fn(Notification<'_>),
40+
process: &Process,
41+
) -> Result<()> {
42+
use crate::download::DownloadError as DEK;
43+
match download_file_(
44+
url,
45+
path,
46+
hasher,
47+
resume_from_partial,
48+
notify_handler,
49+
process,
50+
)
51+
.await
52+
{
53+
Ok(_) => Ok(()),
54+
Err(e) => {
55+
if e.downcast_ref::<std::io::Error>().is_some() {
56+
return Err(e);
57+
}
58+
let is_client_error = match e.downcast_ref::<DEK>() {
59+
// Specifically treat the bad partial range error as not our
60+
// fault in case it was something odd which happened.
61+
Some(DEK::HttpStatus(416)) => false,
62+
Some(DEK::HttpStatus(400..=499)) | Some(DEK::FileNotFound) => true,
63+
_ => false,
64+
};
65+
Err(e).with_context(|| {
66+
if is_client_error {
67+
RustupError::DownloadNotExists {
68+
url: url.clone(),
69+
path: path.to_path_buf(),
70+
}
71+
} else {
72+
RustupError::DownloadingFile {
73+
url: url.clone(),
74+
path: path.to_path_buf(),
75+
}
76+
}
77+
})
78+
}
79+
}
80+
}
81+
82+
async fn download_file_(
83+
url: &Url,
84+
path: &Path,
85+
hasher: Option<&mut Sha256>,
86+
resume_from_partial: bool,
87+
notify_handler: &dyn Fn(Notification<'_>),
88+
process: &Process,
89+
) -> Result<()> {
90+
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
91+
use crate::download::{Backend, Event, TlsBackend};
92+
use sha2::Digest;
93+
use std::cell::RefCell;
94+
95+
notify_handler(Notification::DownloadingFile(url, path));
96+
97+
let hasher = RefCell::new(hasher);
98+
99+
// This callback will write the download to disk and optionally
100+
// hash the contents, then forward the notification up the stack
101+
let callback: &dyn Fn(Event<'_>) -> anyhow::Result<()> = &|msg| {
102+
if let Event::DownloadDataReceived(data) = msg {
103+
if let Some(h) = hasher.borrow_mut().as_mut() {
104+
h.update(data);
105+
}
106+
}
107+
108+
match msg {
109+
Event::DownloadContentLengthReceived(len) => {
110+
notify_handler(Notification::DownloadContentLengthReceived(len));
111+
}
112+
Event::DownloadDataReceived(data) => {
113+
notify_handler(Notification::DownloadDataReceived(data));
114+
}
115+
Event::ResumingPartialDownload => {
116+
notify_handler(Notification::ResumingPartialDownload);
117+
}
118+
}
119+
120+
Ok(())
121+
};
122+
123+
// Download the file
124+
125+
// Keep the curl env var around for a bit
126+
let use_curl_backend = process.var_os("RUSTUP_USE_CURL").map(|it| it != "0");
127+
let use_rustls = process.var_os("RUSTUP_USE_RUSTLS").map(|it| it != "0");
128+
129+
let backend = match (use_curl_backend, use_rustls) {
130+
// If environment specifies a backend that's unavailable, error out
131+
#[cfg(not(feature = "reqwest-rustls-tls"))]
132+
(_, Some(true)) => {
133+
return Err(anyhow!(
134+
"RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature"
135+
));
136+
}
137+
#[cfg(not(feature = "reqwest-native-tls"))]
138+
(_, Some(false)) => {
139+
return Err(anyhow!(
140+
"RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature"
141+
));
142+
}
143+
#[cfg(not(feature = "curl-backend"))]
144+
(Some(true), _) => {
145+
return Err(anyhow!(
146+
"RUSTUP_USE_CURL is set, but this rustup distribution was not built with the curl-backend feature"
147+
));
148+
}
149+
150+
// Positive selections, from least preferred to most preferred
151+
#[cfg(feature = "curl-backend")]
152+
(Some(true), None) => Backend::Curl,
153+
#[cfg(feature = "reqwest-native-tls")]
154+
(_, Some(false)) => {
155+
if use_curl_backend == Some(true) {
156+
info!(
157+
"RUSTUP_USE_CURL is set and RUSTUP_USE_RUSTLS is set to off, using reqwest with native-tls"
158+
);
159+
}
160+
Backend::Reqwest(TlsBackend::NativeTls)
161+
}
162+
#[cfg(feature = "reqwest-rustls-tls")]
163+
_ => {
164+
if use_curl_backend == Some(true) {
165+
info!(
166+
"both RUSTUP_USE_CURL and RUSTUP_USE_RUSTLS are set, using reqwest with rustls"
167+
);
168+
}
169+
Backend::Reqwest(TlsBackend::Rustls)
170+
}
171+
172+
// Falling back if only one backend is available
173+
#[cfg(all(not(feature = "reqwest-rustls-tls"), feature = "reqwest-native-tls"))]
174+
_ => Backend::Reqwest(TlsBackend::NativeTls),
175+
#[cfg(all(
176+
not(feature = "reqwest-rustls-tls"),
177+
not(feature = "reqwest-native-tls"),
178+
feature = "curl-backend"
179+
))]
180+
_ => Backend::Curl,
181+
};
182+
183+
notify_handler(match backend {
184+
#[cfg(feature = "curl-backend")]
185+
Backend::Curl => Notification::UsingCurl,
186+
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
187+
Backend::Reqwest(_) => Notification::UsingReqwest,
188+
});
189+
190+
let res = backend
191+
.download_to_path(url, path, resume_from_partial, Some(callback))
192+
.await;
193+
194+
notify_handler(Notification::DownloadFinished);
195+
196+
res
197+
}
198+
19199
/// User agent header value for HTTP request.
20200
/// See: https://github.com/rust-lang/rustup/issues/2860.
21201
#[cfg(feature = "curl-backend")]
@@ -33,15 +213,15 @@ const REQWEST_RUSTLS_TLS_USER_AGENT: &str =
33213
concat!("rustup/", env!("CARGO_PKG_VERSION"), " (reqwest; rustls)");
34214

35215
#[derive(Debug, Copy, Clone)]
36-
pub enum Backend {
216+
enum Backend {
37217
#[cfg(feature = "curl-backend")]
38218
Curl,
39219
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
40220
Reqwest(TlsBackend),
41221
}
42222

43223
impl Backend {
44-
pub async fn download_to_path(
224+
async fn download_to_path(
45225
self,
46226
url: &Url,
47227
path: &Path,
@@ -175,7 +355,7 @@ impl Backend {
175355

176356
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
177357
#[derive(Debug, Copy, Clone)]
178-
pub enum TlsBackend {
358+
enum TlsBackend {
179359
#[cfg(feature = "reqwest-rustls-tls")]
180360
Rustls,
181361
#[cfg(feature = "reqwest-native-tls")]
@@ -202,7 +382,7 @@ impl TlsBackend {
202382
}
203383

204384
#[derive(Debug, Copy, Clone)]
205-
pub enum Event<'a> {
385+
enum Event<'a> {
206386
ResumingPartialDownload,
207387
/// Received the Content-Length of the to-be downloaded data.
208388
DownloadContentLengthReceived(u64),
@@ -215,7 +395,7 @@ type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> Result<()>;
215395
/// Download via libcurl; encrypt with the native (or OpenSSl) TLS
216396
/// stack via libcurl
217397
#[cfg(feature = "curl-backend")]
218-
pub mod curl {
398+
mod curl {
219399
use std::cell::RefCell;
220400
use std::str;
221401
use std::time::Duration;
@@ -226,7 +406,7 @@ pub mod curl {
226406

227407
use super::{DownloadError, Event};
228408

229-
pub fn download(
409+
pub(super) fn download(
230410
url: &Url,
231411
resume_from: u64,
232412
callback: &dyn Fn(Event<'_>) -> Result<()>,
@@ -327,7 +507,7 @@ pub mod curl {
327507
}
328508

329509
#[cfg(any(feature = "reqwest-rustls-tls", feature = "reqwest-native-tls"))]
330-
pub mod reqwest_be {
510+
mod reqwest_be {
331511
use std::io;
332512
#[cfg(feature = "reqwest-rustls-tls")]
333513
use std::sync::Arc;
@@ -346,7 +526,7 @@ pub mod reqwest_be {
346526

347527
use super::{DownloadError, Event};
348528

349-
pub async fn download(
529+
pub(super) async fn download(
350530
url: &Url,
351531
resume_from: u64,
352532
callback: &dyn Fn(Event<'_>) -> Result<()>,
@@ -491,7 +671,7 @@ pub mod reqwest_be {
491671
}
492672

493673
#[derive(Debug, Error)]
494-
pub enum DownloadError {
674+
enum DownloadError {
495675
#[error("http request returned an unsuccessful status code: {0}")]
496676
HttpStatus(u32),
497677
#[error("file not found")]

0 commit comments

Comments
 (0)