@@ -91,18 +91,18 @@ def __init__(self, params: CIBAAuthorizerParams[ToolInput], auth0: Auth0ClientPa
91
91
self .credentials_store = SubStore [TokenResponse ](ciba_store , {
92
92
"get_ttl" : lambda credential : credential ["expires_in" ] * 1000 if "expires_in" in credential else None
93
93
})
94
-
94
+
95
95
def _handle_authorization_interrupts (self , err : Union [AuthorizationPendingInterrupt , AuthorizationPollingInterrupt ]) -> None :
96
96
raise err
97
-
97
+
98
98
def _get_instance_id (self , authorize_params ) -> str :
99
99
props = {
100
100
"auth0" : omit (self .auth0 , ["client_secret" , "client_assertion_signing_key" ]),
101
101
"params" : authorize_params
102
102
}
103
103
sh = json .dumps (props , sort_keys = True , separators = ("," , ":" ))
104
104
return hashlib .md5 (sh .encode ("utf-8" )).hexdigest ()
105
-
105
+
106
106
async def _get_authorize_params (self , * args : ToolInput .args , ** kwargs : ToolInput .kwargs ) -> Dict [str , Any ]:
107
107
authorize_params = {
108
108
"scope" : _ensure_openid_scope (self .params .get ("scope" )),
@@ -129,11 +129,18 @@ async def _get_authorize_params(self, *args: ToolInput.args, **kwargs: ToolInput
129
129
else :
130
130
authorize_params ["binding_message" ] = self .params .get ("binding_message" )(* args , ** kwargs )
131
131
132
+ if isinstance (self .params .get ("authorization_details" ), list ):
133
+ authorize_params ["authorization_details" ] = self .params .get ("authorization_details" )
134
+ elif inspect .iscoroutinefunction (self .params .get ("authorization_details" )):
135
+ authorize_params ["authorization_details" ] = await self .params .get ("authorization_details" )(* args , ** kwargs )
136
+ else :
137
+ authorize_params ["authorization_details" ] = self .params .get ("authorization_details" )(* args , ** kwargs )
138
+
132
139
return authorize_params
133
140
134
141
async def _start (self , authorize_params ) -> CIBAAuthorizationRequest :
135
142
requested_at = time .time ()
136
-
143
+
137
144
try :
138
145
response = self .back_channel_login .back_channel_login (** authorize_params )
139
146
return CIBAAuthorizationRequest (
@@ -189,7 +196,7 @@ def _get_credentials_internal(self, auth_request: CIBAAuthorizationRequest) -> T
189
196
190
197
def _get_credentials (self , auth_request : CIBAAuthorizationRequest ) -> TokenResponse | None :
191
198
return self ._get_credentials_internal (auth_request )
192
-
199
+
193
200
async def get_credentials_polling (self , auth_request : CIBAAuthorizationRequest ) -> TokenResponse | None :
194
201
credentials : TokenResponse | None = None
195
202
@@ -200,14 +207,14 @@ async def get_credentials_polling(self, auth_request: CIBAAuthorizationRequest)
200
207
await asyncio .sleep (err .request ["interval" ])
201
208
except Exception :
202
209
raise
203
-
210
+
204
211
return credentials
205
-
212
+
206
213
async def delete_auth_request (self ):
207
214
local_store = _get_local_storage ()
208
215
auth_request_ns = local_store ["auth_request_ns" ]
209
216
await self .auth_request_store .delete (auth_request_ns , "auth_request" )
210
-
217
+
211
218
def protect (
212
219
self ,
213
220
get_context : ContextGetter [ToolInput ],
@@ -237,28 +244,28 @@ async def wrapped_execute(*args: ToolInput.args, **kwargs: ToolInput.kwargs):
237
244
# initial request
238
245
auth_request = await self ._start (authorize_params )
239
246
await self .auth_request_store .put (auth_request_ns , "auth_request" , auth_request )
240
-
247
+
241
248
credentials = self ._get_credentials (auth_request )
242
249
else :
243
250
# block mode
244
251
auth_request = await self ._start (authorize_params )
245
252
credentials = await self .get_credentials_polling (auth_request )
246
253
247
254
await self .delete_auth_request ()
248
-
255
+
249
256
if credentials is not None :
250
257
await self .credentials_store .put (credentials_ns , "credential" , credentials )
251
258
except (AuthorizationPendingInterrupt , AuthorizationPollingInterrupt ) as interrupt :
252
259
return self ._handle_authorization_interrupts (interrupt )
253
260
except Exception as err :
254
261
await self .delete_auth_request ()
255
262
raise
256
-
263
+
257
264
_update_local_storage ({"credentials" : credentials })
258
265
259
266
if inspect .iscoroutinefunction (execute ):
260
267
return await execute (* args , ** kwargs )
261
268
else :
262
269
return execute (* args , ** kwargs )
263
-
270
+
264
271
return wrapped_execute
0 commit comments