Skip to content

Commit 22f9f00

Browse files
committed
feat(ciba): support RAR requests
1 parent f1a8531 commit 22f9f00

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

packages/auth0-ai/auth0_ai/authorizers/ciba/ciba_authorizer_base.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,18 @@ def __init__(self, params: CIBAAuthorizerParams[ToolInput], auth0: Auth0ClientPa
9191
self.credentials_store = SubStore[TokenResponse](ciba_store, {
9292
"get_ttl": lambda credential: credential["expires_in"] * 1000 if "expires_in" in credential else None
9393
})
94-
94+
9595
def _handle_authorization_interrupts(self, err: Union[AuthorizationPendingInterrupt, AuthorizationPollingInterrupt]) -> None:
9696
raise err
97-
97+
9898
def _get_instance_id(self, authorize_params) -> str:
9999
props = {
100100
"auth0": omit(self.auth0, ["client_secret", "client_assertion_signing_key"]),
101101
"params": authorize_params
102102
}
103103
sh = json.dumps(props, sort_keys=True, separators=(",", ":"))
104104
return hashlib.md5(sh.encode("utf-8")).hexdigest()
105-
105+
106106
async def _get_authorize_params(self, *args: ToolInput.args, **kwargs: ToolInput.kwargs) -> Dict[str, Any]:
107107
authorize_params = {
108108
"scope": _ensure_openid_scope(self.params.get("scope")),
@@ -129,11 +129,18 @@ async def _get_authorize_params(self, *args: ToolInput.args, **kwargs: ToolInput
129129
else:
130130
authorize_params["binding_message"] = self.params.get("binding_message")(*args, **kwargs)
131131

132+
if isinstance(self.params.get("authorization_details"), list):
133+
authorize_params["authorization_details"] = self.params.get("authorization_details")
134+
elif inspect.iscoroutinefunction(self.params.get("authorization_details")):
135+
authorize_params["authorization_details"] = await self.params.get("authorization_details")(*args, **kwargs)
136+
else:
137+
authorize_params["authorization_details"] = self.params.get("authorization_details")(*args, **kwargs)
138+
132139
return authorize_params
133140

134141
async def _start(self, authorize_params) -> CIBAAuthorizationRequest:
135142
requested_at = time.time()
136-
143+
137144
try:
138145
response = self.back_channel_login.back_channel_login(**authorize_params)
139146
return CIBAAuthorizationRequest(
@@ -189,7 +196,7 @@ def _get_credentials_internal(self, auth_request: CIBAAuthorizationRequest) -> T
189196

190197
def _get_credentials(self, auth_request: CIBAAuthorizationRequest) -> TokenResponse | None:
191198
return self._get_credentials_internal(auth_request)
192-
199+
193200
async def get_credentials_polling(self, auth_request: CIBAAuthorizationRequest) -> TokenResponse | None:
194201
credentials: TokenResponse | None = None
195202

@@ -200,14 +207,14 @@ async def get_credentials_polling(self, auth_request: CIBAAuthorizationRequest)
200207
await asyncio.sleep(err.request["interval"])
201208
except Exception:
202209
raise
203-
210+
204211
return credentials
205-
212+
206213
async def delete_auth_request(self):
207214
local_store = _get_local_storage()
208215
auth_request_ns = local_store["auth_request_ns"]
209216
await self.auth_request_store.delete(auth_request_ns, "auth_request")
210-
217+
211218
def protect(
212219
self,
213220
get_context: ContextGetter[ToolInput],
@@ -237,28 +244,28 @@ async def wrapped_execute(*args: ToolInput.args, **kwargs: ToolInput.kwargs):
237244
# initial request
238245
auth_request = await self._start(authorize_params)
239246
await self.auth_request_store.put(auth_request_ns, "auth_request", auth_request)
240-
247+
241248
credentials = self._get_credentials(auth_request)
242249
else:
243250
# block mode
244251
auth_request = await self._start(authorize_params)
245252
credentials = await self.get_credentials_polling(auth_request)
246253

247254
await self.delete_auth_request()
248-
255+
249256
if credentials is not None:
250257
await self.credentials_store.put(credentials_ns, "credential", credentials)
251258
except (AuthorizationPendingInterrupt, AuthorizationPollingInterrupt) as interrupt:
252259
return self._handle_authorization_interrupts(interrupt)
253260
except Exception as err:
254261
await self.delete_auth_request()
255262
raise
256-
263+
257264
_update_local_storage({"credentials": credentials})
258265

259266
if inspect.iscoroutinefunction(execute):
260267
return await execute(*args, **kwargs)
261268
else:
262269
return execute(*args, **kwargs)
263-
270+
264271
return wrapped_execute

packages/auth0-ai/auth0_ai/authorizers/ciba/ciba_authorizer_params.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,18 @@
33
from auth0_ai.authorizers.types import ToolInput
44
from auth0_ai.stores import Store
55

6+
67
class CIBAAuthorizerParams(TypedDict, Generic[ToolInput]):
78
"""
89
Authorize Options to start CIBA flow.
910
1011
Attributes:
1112
scope (list[str]): The scopes to request authorization for.
1213
binding_message (Union[str, Callable[..., Awaitable[str]]]): A human-readable string to display to the user, or a function that resolves it.
14+
authorization_details (Union[list[dict], Callable[..., Awaitable[list[dict]]]]):
15+
Authorization details that specify what the user is authorizing. Can be:
16+
- A list of dictionaries with authorization details (e.g., [{ type: "custom_type", param: "example", ...}] details)
17+
- A function that resolves to a list of dictionaries
1318
user_id (Union[str, Callable[..., Awaitable[str]]]): The user id string, or a function that resolves it.
1419
store (Store, optional): An store used to temporarly store the authorization response data while the user is completing the authorization in another device (default: InMemoryStore).
1520
audience (str, optional): The audience to request authorization for.
@@ -25,6 +30,7 @@ class CIBAAuthorizerParams(TypedDict, Generic[ToolInput]):
2530
"""
2631
scopes: list[str]
2732
binding_message: Union[str, Callable[ToolInput, Awaitable[str]]]
33+
authorization_details: Union[list[dict], Callable[ToolInput, Awaitable[list[dict]]]]
2834
user_id: Union[str, Callable[ToolInput, Awaitable[str]]]
2935
store: Optional[Store]
3036
audience: Optional[str]

0 commit comments

Comments
 (0)