Skip to content

feat(federated-connection): add missing method get_access_token_for_connection #37

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@
from github.GithubException import BadCredentialsException
from llama_index.core.tools import FunctionTool

from auth0_ai_llamaindex.federated_connections import FederatedConnectionError, get_credentials_for_connection
from auth0_ai_llamaindex.federated_connections import FederatedConnectionError, get_access_token_for_connection
from src.auth0.auth0_ai import with_github_access


def list_repositories_tool_function():
credentials = get_credentials_for_connection()
if not credentials:
access_token = get_access_token_for_connection()
if not access_token:
raise ValueError(
"Authorization required to access the Federated Connection API")

# GitHub SDK
try:
g = Github(credentials["access_token"])
g = Github(access_token)
user = g.get_user()
repos = user.get_repos()
repo_names = [repo.name for repo in repos]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,8 @@
FederatedConnectionInterrupt as FederatedConnectionInterrupt
)

from auth0_ai.authorizers.federated_connection_authorizer import get_credentials_for_connection as get_credentials_for_connection
from auth0_ai.authorizers.federated_connection_authorizer import (
get_credentials_for_connection as get_credentials_for_connection,
get_access_token_for_connection as get_access_token_for_connection
)
from .federated_connection_authorizer import FederatedConnectionAuthorizer as FederatedConnectionAuthorizer
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,8 @@
FederatedConnectionInterrupt as FederatedConnectionInterrupt
)

from auth0_ai.authorizers.federated_connection_authorizer import get_credentials_for_connection as get_credentials_for_connection
from auth0_ai.authorizers.federated_connection_authorizer import (
get_credentials_for_connection as get_credentials_for_connection,
get_access_token_for_connection as get_access_token_for_connection
)
from auth0_ai_llamaindex.federated_connections.federated_connection_authorizer import FederatedConnectionAuthorizer as FederatedConnectionAuthorizer
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def get_credentials_for_connection() -> TokenResponse | None:
store = _get_local_storage()
return store.get("credentials")

def get_access_token_for_connection() -> str | None:
store = _get_local_storage()
return store.get("credentials", {}).get("access_token")

class FederatedConnectionAuthorizerParams(Generic[ToolInput]):
def __init__(
self,
Expand Down Expand Up @@ -93,7 +97,7 @@ def wrap(val, result_type):
if isinstance(val, AuthorizerToolParameter):
return val
return AuthorizerToolParameter[ToolInput, result_type](val)

self.scopes = scopes
self.connection = connection
self.refresh_token = wrap(refresh_token, str | None)
Expand Down Expand Up @@ -136,26 +140,26 @@ def __init__(
# Ensure either refreshToken or accessToken is provided
if params.refresh_token.value is None and params.access_token.value is None:
raise ValueError("Either refresh_token or access_token must be provided to initialize the Authorizer.")

if params.refresh_token.value is not None and params.access_token.value is not None:
raise ValueError("Only one of refresh_token or access_token can be provided to initialize the Authorizer.")

def _handle_authorization_interrupts(self, err: Auth0Interrupt) -> None:
raise err

def _get_instance_id(self) -> str:
props = {
"auth0": omit(self.auth0, ["client_secret", "client_assertion_signing_key"]),
"params": omit(self.params, ["store", "refresh_token", "access_token"])
}
sh = json.dumps(props, sort_keys=True, separators=(",", ":"))
return hashlib.md5(sh.encode("utf-8")).hexdigest()

def validate_token(self, token_response: Optional[TokenResponse] = None):
store = _get_local_storage()
scopes = store["scopes"]
connection = store["connection"]

if token_response is None:
raise FederatedConnectionInterrupt(
f"Authorization required to access the Federated Connection API: {connection}",
Expand All @@ -167,7 +171,7 @@ def validate_token(self, token_response: Optional[TokenResponse] = None):
current_scopes = token_response["scope"]
missing_scopes = [s for s in scopes if s not in current_scopes]
_update_local_storage({"current_scopes": current_scopes})

if missing_scopes:
raise FederatedConnectionInterrupt(
f"Authorization required to access the Federated Connection API: {connection}. Missing scopes: {', '.join(missing_scopes)}",
Expand All @@ -178,20 +182,20 @@ def validate_token(self, token_response: Optional[TokenResponse] = None):

async def get_access_token_impl(self, *args: ToolInput.args, **kwargs: ToolInput.kwargs) -> TokenResponse | None:
store = _get_local_storage()

connection = store["connection"]
subject_token = await self.get_refresh_token(*args, **kwargs)
if not subject_token:
return None

try:
response = self.get_token.access_token_for_connection(
subject_token_type="urn:ietf:params:oauth:token-type:refresh_token",
subject_token=subject_token,
requested_token_type="http://auth0.com/oauth/token-type/federated-connection-access-token",
connection=connection,
)

return TokenResponse(
access_token=response["access_token"],
expires_in=response["expires_in"],
Expand All @@ -202,19 +206,19 @@ async def get_access_token_impl(self, *args: ToolInput.args, **kwargs: ToolInput
)
except Auth0Error as err:
raise FederatedConnectionError(err.message) if 400 <= err.status_code <= 499 else err

async def get_access_token(self, *args: ToolInput.args, **kwargs: ToolInput.kwargs) -> TokenResponse | None:
if callable(self.params.refresh_token.value) or asyncio.iscoroutinefunction(self.params.refresh_token.value):
token_response = await self.get_access_token_impl(*args, **kwargs)
else:
token_response = await self.params.access_token.resolve(*args, **kwargs)

self.validate_token(token_response)
return token_response

async def get_refresh_token(self, *args: ToolInput.args, **kwargs: ToolInput.kwargs):
return await self.params.refresh_token.resolve(*args, **kwargs)

def protect(
self,
get_context: ContextGetter[ToolInput],
Expand All @@ -233,11 +237,11 @@ async def wrapped_execute(*args: ToolInput.args, **kwargs: ToolInput.kwargs):

try:
credentials = await self.credentials_store.get(credentials_ns, "credential")

if not credentials:
credentials = await self.get_access_token(*args, **kwargs)
await self.credentials_store.put(credentials_ns, "credential", credentials)

_update_local_storage({"credentials": credentials})

if inspect.iscoroutinefunction(execute):
Expand All @@ -256,5 +260,5 @@ async def wrapped_execute(*args: ToolInput.args, **kwargs: ToolInput.kwargs):
except Auth0Interrupt as err:
self.credentials_store.delete(credentials_ns, "credential")
return self._handle_authorization_interrupts(err)

return wrapped_execute