@@ -10,12 +10,192 @@ use std::path::Path;
10
10
) ) ]
11
11
use anyhow:: anyhow;
12
12
use anyhow:: { Context , Result } ;
13
+ use sha2:: Sha256 ;
13
14
use thiserror:: Error ;
15
+ #[ cfg( any( feature = "reqwest-rustls-tls" , feature = "reqwest-native-tls" ) ) ]
16
+ use tracing:: info;
14
17
use url:: Url ;
15
18
19
+ use crate :: { errors:: RustupError , process:: Process , utils:: Notification } ;
20
+
16
21
#[ cfg( test) ]
17
22
mod tests;
18
23
24
+ pub ( crate ) async fn download_file (
25
+ url : & Url ,
26
+ path : & Path ,
27
+ hasher : Option < & mut Sha256 > ,
28
+ notify_handler : & dyn Fn ( Notification < ' _ > ) ,
29
+ process : & Process ,
30
+ ) -> Result < ( ) > {
31
+ download_file_with_resume ( url, path, hasher, false , & notify_handler, process) . await
32
+ }
33
+
34
+ pub ( crate ) async fn download_file_with_resume (
35
+ url : & Url ,
36
+ path : & Path ,
37
+ hasher : Option < & mut Sha256 > ,
38
+ resume_from_partial : bool ,
39
+ notify_handler : & dyn Fn ( Notification < ' _ > ) ,
40
+ process : & Process ,
41
+ ) -> Result < ( ) > {
42
+ use crate :: download:: DownloadError as DEK ;
43
+ match download_file_ (
44
+ url,
45
+ path,
46
+ hasher,
47
+ resume_from_partial,
48
+ notify_handler,
49
+ process,
50
+ )
51
+ . await
52
+ {
53
+ Ok ( _) => Ok ( ( ) ) ,
54
+ Err ( e) => {
55
+ if e. downcast_ref :: < std:: io:: Error > ( ) . is_some ( ) {
56
+ return Err ( e) ;
57
+ }
58
+ let is_client_error = match e. downcast_ref :: < DEK > ( ) {
59
+ // Specifically treat the bad partial range error as not our
60
+ // fault in case it was something odd which happened.
61
+ Some ( DEK :: HttpStatus ( 416 ) ) => false ,
62
+ Some ( DEK :: HttpStatus ( 400 ..=499 ) ) | Some ( DEK :: FileNotFound ) => true ,
63
+ _ => false ,
64
+ } ;
65
+ Err ( e) . with_context ( || {
66
+ if is_client_error {
67
+ RustupError :: DownloadNotExists {
68
+ url : url. clone ( ) ,
69
+ path : path. to_path_buf ( ) ,
70
+ }
71
+ } else {
72
+ RustupError :: DownloadingFile {
73
+ url : url. clone ( ) ,
74
+ path : path. to_path_buf ( ) ,
75
+ }
76
+ }
77
+ } )
78
+ }
79
+ }
80
+ }
81
+
82
+ async fn download_file_ (
83
+ url : & Url ,
84
+ path : & Path ,
85
+ hasher : Option < & mut Sha256 > ,
86
+ resume_from_partial : bool ,
87
+ notify_handler : & dyn Fn ( Notification < ' _ > ) ,
88
+ process : & Process ,
89
+ ) -> Result < ( ) > {
90
+ #[ cfg( any( feature = "reqwest-rustls-tls" , feature = "reqwest-native-tls" ) ) ]
91
+ use crate :: download:: { Backend , Event , TlsBackend } ;
92
+ use sha2:: Digest ;
93
+ use std:: cell:: RefCell ;
94
+
95
+ notify_handler ( Notification :: DownloadingFile ( url, path) ) ;
96
+
97
+ let hasher = RefCell :: new ( hasher) ;
98
+
99
+ // This callback will write the download to disk and optionally
100
+ // hash the contents, then forward the notification up the stack
101
+ let callback: & dyn Fn ( Event < ' _ > ) -> anyhow:: Result < ( ) > = & |msg| {
102
+ if let Event :: DownloadDataReceived ( data) = msg {
103
+ if let Some ( h) = hasher. borrow_mut ( ) . as_mut ( ) {
104
+ h. update ( data) ;
105
+ }
106
+ }
107
+
108
+ match msg {
109
+ Event :: DownloadContentLengthReceived ( len) => {
110
+ notify_handler ( Notification :: DownloadContentLengthReceived ( len) ) ;
111
+ }
112
+ Event :: DownloadDataReceived ( data) => {
113
+ notify_handler ( Notification :: DownloadDataReceived ( data) ) ;
114
+ }
115
+ Event :: ResumingPartialDownload => {
116
+ notify_handler ( Notification :: ResumingPartialDownload ) ;
117
+ }
118
+ }
119
+
120
+ Ok ( ( ) )
121
+ } ;
122
+
123
+ // Download the file
124
+
125
+ // Keep the curl env var around for a bit
126
+ let use_curl_backend = process. var_os ( "RUSTUP_USE_CURL" ) . map ( |it| it != "0" ) ;
127
+ let use_rustls = process. var_os ( "RUSTUP_USE_RUSTLS" ) . map ( |it| it != "0" ) ;
128
+
129
+ let backend = match ( use_curl_backend, use_rustls) {
130
+ // If environment specifies a backend that's unavailable, error out
131
+ #[ cfg( not( feature = "reqwest-rustls-tls" ) ) ]
132
+ ( _, Some ( true ) ) => {
133
+ return Err ( anyhow ! (
134
+ "RUSTUP_USE_RUSTLS is set, but this rustup distribution was not built with the reqwest-rustls-tls feature"
135
+ ) ) ;
136
+ }
137
+ #[ cfg( not( feature = "reqwest-native-tls" ) ) ]
138
+ ( _, Some ( false ) ) => {
139
+ return Err ( anyhow ! (
140
+ "RUSTUP_USE_RUSTLS is set to false, but this rustup distribution was not built with the reqwest-native-tls feature"
141
+ ) ) ;
142
+ }
143
+ #[ cfg( not( feature = "curl-backend" ) ) ]
144
+ ( Some ( true ) , _) => {
145
+ return Err ( anyhow ! (
146
+ "RUSTUP_USE_CURL is set, but this rustup distribution was not built with the curl-backend feature"
147
+ ) ) ;
148
+ }
149
+
150
+ // Positive selections, from least preferred to most preferred
151
+ #[ cfg( feature = "curl-backend" ) ]
152
+ ( Some ( true ) , None ) => Backend :: Curl ,
153
+ #[ cfg( feature = "reqwest-native-tls" ) ]
154
+ ( _, Some ( false ) ) => {
155
+ if use_curl_backend == Some ( true ) {
156
+ info ! (
157
+ "RUSTUP_USE_CURL is set and RUSTUP_USE_RUSTLS is set to off, using reqwest with native-tls"
158
+ ) ;
159
+ }
160
+ Backend :: Reqwest ( TlsBackend :: NativeTls )
161
+ }
162
+ #[ cfg( feature = "reqwest-rustls-tls" ) ]
163
+ _ => {
164
+ if use_curl_backend == Some ( true ) {
165
+ info ! (
166
+ "both RUSTUP_USE_CURL and RUSTUP_USE_RUSTLS are set, using reqwest with rustls"
167
+ ) ;
168
+ }
169
+ Backend :: Reqwest ( TlsBackend :: Rustls )
170
+ }
171
+
172
+ // Falling back if only one backend is available
173
+ #[ cfg( all( not( feature = "reqwest-rustls-tls" ) , feature = "reqwest-native-tls" ) ) ]
174
+ _ => Backend :: Reqwest ( TlsBackend :: NativeTls ) ,
175
+ #[ cfg( all(
176
+ not( feature = "reqwest-rustls-tls" ) ,
177
+ not( feature = "reqwest-native-tls" ) ,
178
+ feature = "curl-backend"
179
+ ) ) ]
180
+ _ => Backend :: Curl ,
181
+ } ;
182
+
183
+ notify_handler ( match backend {
184
+ #[ cfg( feature = "curl-backend" ) ]
185
+ Backend :: Curl => Notification :: UsingCurl ,
186
+ #[ cfg( any( feature = "reqwest-rustls-tls" , feature = "reqwest-native-tls" ) ) ]
187
+ Backend :: Reqwest ( _) => Notification :: UsingReqwest ,
188
+ } ) ;
189
+
190
+ let res = backend
191
+ . download_to_path ( url, path, resume_from_partial, Some ( callback) )
192
+ . await ;
193
+
194
+ notify_handler ( Notification :: DownloadFinished ) ;
195
+
196
+ res
197
+ }
198
+
19
199
/// User agent header value for HTTP request.
20
200
/// See: https://github.com/rust-lang/rustup/issues/2860.
21
201
#[ cfg( feature = "curl-backend" ) ]
@@ -33,15 +213,15 @@ const REQWEST_RUSTLS_TLS_USER_AGENT: &str =
33
213
concat ! ( "rustup/" , env!( "CARGO_PKG_VERSION" ) , " (reqwest; rustls)" ) ;
34
214
35
215
#[ derive( Debug , Copy , Clone ) ]
36
- pub enum Backend {
216
+ enum Backend {
37
217
#[ cfg( feature = "curl-backend" ) ]
38
218
Curl ,
39
219
#[ cfg( any( feature = "reqwest-rustls-tls" , feature = "reqwest-native-tls" ) ) ]
40
220
Reqwest ( TlsBackend ) ,
41
221
}
42
222
43
223
impl Backend {
44
- pub async fn download_to_path (
224
+ async fn download_to_path (
45
225
self ,
46
226
url : & Url ,
47
227
path : & Path ,
@@ -175,7 +355,7 @@ impl Backend {
175
355
176
356
#[ cfg( any( feature = "reqwest-rustls-tls" , feature = "reqwest-native-tls" ) ) ]
177
357
#[ derive( Debug , Copy , Clone ) ]
178
- pub enum TlsBackend {
358
+ enum TlsBackend {
179
359
#[ cfg( feature = "reqwest-rustls-tls" ) ]
180
360
Rustls ,
181
361
#[ cfg( feature = "reqwest-native-tls" ) ]
@@ -202,7 +382,7 @@ impl TlsBackend {
202
382
}
203
383
204
384
#[ derive( Debug , Copy , Clone ) ]
205
- pub enum Event < ' a > {
385
+ enum Event < ' a > {
206
386
ResumingPartialDownload ,
207
387
/// Received the Content-Length of the to-be downloaded data.
208
388
DownloadContentLengthReceived ( u64 ) ,
@@ -215,7 +395,7 @@ type DownloadCallback<'a> = &'a dyn Fn(Event<'_>) -> Result<()>;
215
395
/// Download via libcurl; encrypt with the native (or OpenSSl) TLS
216
396
/// stack via libcurl
217
397
#[ cfg( feature = "curl-backend" ) ]
218
- pub mod curl {
398
+ mod curl {
219
399
use std:: cell:: RefCell ;
220
400
use std:: str;
221
401
use std:: time:: Duration ;
@@ -226,7 +406,7 @@ pub mod curl {
226
406
227
407
use super :: { DownloadError , Event } ;
228
408
229
- pub fn download (
409
+ pub ( super ) fn download (
230
410
url : & Url ,
231
411
resume_from : u64 ,
232
412
callback : & dyn Fn ( Event < ' _ > ) -> Result < ( ) > ,
@@ -327,7 +507,7 @@ pub mod curl {
327
507
}
328
508
329
509
#[ cfg( any( feature = "reqwest-rustls-tls" , feature = "reqwest-native-tls" ) ) ]
330
- pub mod reqwest_be {
510
+ mod reqwest_be {
331
511
use std:: io;
332
512
#[ cfg( feature = "reqwest-rustls-tls" ) ]
333
513
use std:: sync:: Arc ;
@@ -346,7 +526,7 @@ pub mod reqwest_be {
346
526
347
527
use super :: { DownloadError , Event } ;
348
528
349
- pub async fn download (
529
+ pub ( super ) async fn download (
350
530
url : & Url ,
351
531
resume_from : u64 ,
352
532
callback : & dyn Fn ( Event < ' _ > ) -> Result < ( ) > ,
@@ -491,7 +671,7 @@ pub mod reqwest_be {
491
671
}
492
672
493
673
#[ derive( Debug , Error ) ]
494
- pub enum DownloadError {
674
+ enum DownloadError {
495
675
#[ error( "http request returned an unsuccessful status code: {0}" ) ]
496
676
HttpStatus ( u32 ) ,
497
677
#[ error( "file not found" ) ]
0 commit comments