4
4
//!
5
5
//! You can learn more about this authorization flow [here](https://docs.microsoft.com/azure/active-directory/develop/v2-oauth2-device-code).
6
6
mod device_code_responses;
7
- use crate :: Error ;
8
- use async_timer:: timer:: new_timer;
7
+
9
8
pub use device_code_responses:: * ;
9
+
10
+ use async_timer:: timer:: new_timer;
10
11
use futures:: stream:: unfold;
11
- use log:: debug;
12
12
use oauth2:: ClientId ;
13
13
use serde:: Deserialize ;
14
+ use url:: form_urlencoded;
15
+
14
16
use std:: borrow:: Cow ;
15
17
use std:: convert:: TryInto ;
16
18
use std:: time:: Duration ;
17
- use url:: form_urlencoded;
18
19
19
20
pub async fn start < ' a , ' b , T > (
20
21
client : & ' a reqwest:: Client ,
21
22
tenant_id : T ,
22
23
client_id : & ' a ClientId ,
23
24
scopes : & ' b [ & ' b str ] ,
24
- ) -> Result < DeviceCodePhaseOneResponse < ' a > , Error >
25
+ ) -> Result < DeviceCodePhaseOneResponse < ' a > , DeviceCodeError >
25
26
where
26
27
T : Into < Cow < ' a , str > > ,
27
28
{
@@ -32,41 +33,48 @@ where
32
33
33
34
let tenant_id = tenant_id. into ( ) ;
34
35
35
- debug ! ( "encoded ==> {}" , encoded) ;
36
-
37
36
let url = url:: Url :: parse ( & format ! (
38
37
"https://login.microsoftonline.com/{}/oauth2/v2.0/devicecode" ,
39
38
tenant_id
40
- ) ) ?;
39
+ ) )
40
+ . map_err ( |_| DeviceCodeError :: InvalidTenantId ( tenant_id. clone ( ) . into_owned ( ) ) ) ?;
41
41
42
- client
42
+ let response = client
43
43
. post ( url)
44
44
. header ( "ContentType" , "application/x-www-form-urlencoded" )
45
45
. body ( encoded)
46
46
. send ( )
47
- . await ?
47
+ . await
48
+ . map_err ( |e| DeviceCodeError :: RequestError ( Box :: new ( e) ) ) ?;
49
+
50
+ if !response. status ( ) . is_success ( ) {
51
+ return Err ( DeviceCodeError :: UnsuccessfulResponse (
52
+ response. status ( ) . as_u16 ( ) ,
53
+ response. text ( ) . await . ok ( ) ,
54
+ ) ) ;
55
+ }
56
+ let s = response
48
57
. text ( )
49
58
. await
50
- . map ( |s| -> Result < DeviceCodePhaseOneResponse , Error > {
51
- serde_json:: from_str :: < DeviceCodePhaseOneResponse > ( & s)
52
- // we need to capture some variables that will be useful in
53
- // the second phase (the client, the tenant_id and the client_id)
54
- . map ( |device_code_reponse| {
55
- Ok ( DeviceCodePhaseOneResponse {
56
- device_code : device_code_reponse. device_code ,
57
- user_code : device_code_reponse. user_code ,
58
- verification_uri : device_code_reponse. verification_uri ,
59
- expires_in : device_code_reponse. expires_in ,
60
- interval : device_code_reponse. interval ,
61
- message : device_code_reponse. message ,
62
- client : Some ( client) ,
63
- tenant_id,
64
- client_id : client_id. as_str ( ) . to_string ( ) ,
65
- } )
66
- } ) ?
67
- // TODO The HTTP status code should be checked to deserialize an error response.
68
- // serde_json::from_str::<crate::errors::ErrorResponse>(&s).map(Error::ErrorResponse)
69
- } ) ?
59
+ . map_err ( |e| DeviceCodeError :: RequestError ( Box :: new ( e) ) ) ?;
60
+
61
+ serde_json:: from_str :: < DeviceCodePhaseOneResponse > ( & s)
62
+ // we need to capture some variables that will be useful in
63
+ // the second phase (the client, the tenant_id and the client_id)
64
+ . map ( |device_code_reponse| {
65
+ Ok ( DeviceCodePhaseOneResponse {
66
+ device_code : device_code_reponse. device_code ,
67
+ user_code : device_code_reponse. user_code ,
68
+ verification_uri : device_code_reponse. verification_uri ,
69
+ expires_in : device_code_reponse. expires_in ,
70
+ interval : device_code_reponse. interval ,
71
+ message : device_code_reponse. message ,
72
+ client : Some ( client) ,
73
+ tenant_id,
74
+ client_id : client_id. as_str ( ) . to_string ( ) ,
75
+ } )
76
+ } )
77
+ . map_err ( |_| DeviceCodeError :: InvalidResponseBody ( s) ) ?
70
78
}
71
79
72
80
#[ derive( Debug , Clone , Deserialize ) ]
@@ -77,17 +85,14 @@ pub struct DeviceCodePhaseOneResponse<'a> {
77
85
expires_in : u64 ,
78
86
interval : u64 ,
79
87
message : String ,
80
- // the skipped fields below do not come
81
- // from the Azure answer. They will be added
82
- // manually after deserialization
88
+ // The skipped fields below do not come from the Azure answer.
89
+ // They will be added manually after deserialization
83
90
#[ serde( skip) ]
84
91
client : Option < & ' a reqwest:: Client > ,
85
92
#[ serde( skip) ]
86
93
tenant_id : Cow < ' a , str > ,
87
- // we store the ClientId as string instead of
88
- // the original type because it does not
89
- // implement Default and it's in another
90
- // create
94
+ // We store the ClientId as string instead of the original type, because it
95
+ // does not implement Default, and it's in another crate
91
96
#[ serde( skip) ]
92
97
client_id : String ,
93
98
}
@@ -97,9 +102,9 @@ impl<'a> DeviceCodePhaseOneResponse<'a> {
97
102
& self . message
98
103
}
99
104
100
- pub fn stream < ' b > (
101
- & ' b self ,
102
- ) -> impl futures:: Stream < Item = Result < DeviceCodeResponse , DeviceCodeError > > + ' b + ' _ {
105
+ pub fn stream (
106
+ & self ,
107
+ ) -> impl futures:: Stream < Item = Result < DeviceCodeResponse , DeviceCodeError > > + ' _ {
103
108
#[ derive( Debug , Clone , PartialEq ) ]
104
109
enum NextState {
105
110
Continue ,
@@ -114,12 +119,10 @@ impl<'a> DeviceCodePhaseOneResponse<'a> {
114
119
self . tenant_id,
115
120
) ;
116
121
117
- // throttle down as specified by Azure. This could be
122
+ // Throttle down as specified by Azure. This could be
118
123
// smarter: we could calculate the elapsed time since the
119
- // last poll and wait only the delta. For now we do not
120
- // need such precision.
124
+ // last poll and wait only the delta.
121
125
new_timer ( Duration :: from_secs ( self . interval ) ) . await ;
122
- debug ! ( "posting to {}" , & uri) ;
123
126
124
127
let mut encoded = form_urlencoded:: Serializer :: new ( String :: new ( ) ) ;
125
128
let encoded = encoded
@@ -136,22 +139,24 @@ impl<'a> DeviceCodePhaseOneResponse<'a> {
136
139
. body ( encoded)
137
140
. send ( )
138
141
. await
139
- . map_err ( DeviceCodeError :: ReqwestError )
142
+ . map_err ( |e| DeviceCodeError :: RequestError ( Box :: new ( e ) ) )
140
143
{
141
144
Ok ( result) => result,
142
145
Err ( error) => return Some ( ( Err ( error) , NextState :: Finish ) ) ,
143
146
} ;
144
- debug ! ( "result (raw) ==> {:?}" , result) ;
145
147
146
- let result = match result. text ( ) . await . map_err ( DeviceCodeError :: ReqwestError ) {
148
+ let result = match result
149
+ . text ( )
150
+ . await
151
+ . map_err ( |e| DeviceCodeError :: RequestError ( Box :: new ( e) ) )
152
+ {
147
153
Ok ( result) => result,
148
154
Err ( error) => return Some ( ( Err ( error) , NextState :: Finish ) ) ,
149
155
} ;
150
- debug ! ( "result (as text) ==> {}" , result) ;
151
156
152
- // here either we get an error response from Azure
157
+ // Here either we get an error response from Azure
153
158
// or we get a success. A success can be either "Pending" or
154
- // "Completed". We finish the loop only on "Completed" (ie Success)
159
+ // "Completed". We finish the loop only on "Completed"
155
160
match result. try_into ( ) {
156
161
Ok ( device_code_response) => {
157
162
let next_state = match & device_code_response {
0 commit comments