diff --git a/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/blobClient.py b/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/blobClient.py index c5a2bdf..71a92f3 100644 --- a/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/blobClient.py +++ b/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/blobClient.py @@ -3,10 +3,10 @@ import json -from azure.identity import DefaultAzureCredential -from azure.storage.blob import BlobServiceClient from azurefunctions.extensions.base import Datum, SdkType -from .utils import get_connection_string, using_managed_identity +from .utils import (using_system_managed_identity, + using_user_managed_identity, + get_blob_service_client) class BlobClient(SdkType): @@ -25,9 +25,12 @@ def __init__(self, *, data: Datum) -> None: self._source = data.source self._content_type = data.content_type content_json = json.loads(data.content) - self._connection = get_connection_string(content_json.get("Connection")) - self._using_managed_identity = using_managed_identity( - content_json.get("Connection") + self._connection = content_json.get("Connection") + self._system_managed_identity = using_system_managed_identity( + self._connection + ) + self._user_managed_identity = using_user_managed_identity( + self._connection ) self._containerName = content_json.get("ContainerName") self._blobName = content_json.get("BlobName") @@ -38,19 +41,17 @@ def get_sdk_type(self): through a BlobServiceClient. There are two ways to create a BlobServiceClient: 1. Through the constructor: this is the only option when using Managed Identity + 1a. If system-based MI, the credential is DefaultAzureCredential + 1b. If user-based MI, the credential is ManagedIdentityCredential 2. Through from_connection_string: this is the only option when not using Managed Identity We track if Managed Identity is being used through a flag. """ if self._data: - blob_service_client = ( - BlobServiceClient( - account_url=self._connection, credential=DefaultAzureCredential() - ) - if self._using_managed_identity - else BlobServiceClient.from_connection_string(self._connection) - ) + blob_service_client = get_blob_service_client(self._system_managed_identity, + self._user_managed_identity, + self._connection) return blob_service_client.get_blob_client( container=self._containerName, blob=self._blobName, diff --git a/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/containerClient.py b/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/containerClient.py index 8f83bd4..552a028 100644 --- a/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/containerClient.py +++ b/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/containerClient.py @@ -3,10 +3,10 @@ import json -from azure.identity import DefaultAzureCredential -from azure.storage.blob import BlobServiceClient from azurefunctions.extensions.base import Datum, SdkType -from .utils import get_connection_string, using_managed_identity +from .utils import (using_system_managed_identity, + using_user_managed_identity, + get_blob_service_client) class ContainerClient(SdkType): @@ -25,9 +25,12 @@ def __init__(self, *, data: Datum) -> None: self._source = data.source self._content_type = data.content_type content_json = json.loads(data.content) - self._connection = get_connection_string(content_json.get("Connection")) - self._using_managed_identity = using_managed_identity( - content_json.get("Connection") + self._connection = content_json.get("Connection") + self._system_managed_identity = using_system_managed_identity( + self._connection + ) + self._user_managed_identity = using_user_managed_identity( + self._connection ) self._containerName = content_json.get("ContainerName") self._blobName = content_json.get("BlobName") @@ -35,13 +38,9 @@ def __init__(self, *, data: Datum) -> None: # Returns a ContainerClient def get_sdk_type(self): if self._data: - blob_service_client = ( - BlobServiceClient( - account_url=self._connection, credential=DefaultAzureCredential() - ) - if self._using_managed_identity - else BlobServiceClient.from_connection_string(self._connection) - ) + blob_service_client = get_blob_service_client(self._system_managed_identity, + self._user_managed_identity, + self._connection) return blob_service_client.get_container_client( container=self._containerName ) diff --git a/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/storageStreamDownloader.py b/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/storageStreamDownloader.py index a21c3ca..6cf58f3 100644 --- a/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/storageStreamDownloader.py +++ b/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/storageStreamDownloader.py @@ -3,10 +3,10 @@ import json -from azure.identity import DefaultAzureCredential -from azure.storage.blob import BlobServiceClient from azurefunctions.extensions.base import Datum, SdkType -from .utils import get_connection_string, using_managed_identity +from .utils import (using_system_managed_identity, + using_user_managed_identity, + get_blob_service_client) class StorageStreamDownloader(SdkType): @@ -25,9 +25,12 @@ def __init__(self, *, data: Datum) -> None: self._source = data.source self._content_type = data.content_type content_json = json.loads(data.content) - self._connection = get_connection_string(content_json.get("Connection")) - self._using_managed_identity = using_managed_identity( - content_json.get("Connection") + self._connection = content_json.get("Connection") + self._system_managed_identity = using_system_managed_identity( + self._connection + ) + self._user_managed_identity = using_user_managed_identity( + self._connection ) self._containerName = content_json.get("ContainerName") self._blobName = content_json.get("BlobName") @@ -35,14 +38,9 @@ def __init__(self, *, data: Datum) -> None: # Returns a StorageStreamDownloader def get_sdk_type(self): if self._data: - blob_service_client = ( - BlobServiceClient( - account_url=self._connection, credential=DefaultAzureCredential() - ) - if self._using_managed_identity - else BlobServiceClient.from_connection_string(self._connection) - ) - # download_blob() returns a StorageStreamDownloader object + blob_service_client = get_blob_service_client(self._system_managed_identity, + self._user_managed_identity, + self._connection) return blob_service_client.get_blob_client( container=self._containerName, blob=self._blobName, diff --git a/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/utils.py b/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/utils.py index 5421061..3c99862 100644 --- a/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/utils.py +++ b/azurefunctions-extensions-bindings-blob/azurefunctions/extensions/bindings/blob/utils.py @@ -2,6 +2,9 @@ # Licensed under the MIT License. import os +from azure.identity import DefaultAzureCredential, ManagedIdentityCredential +from azure.storage.blob import BlobServiceClient + def get_connection_string(connection_string: str) -> str: """ @@ -39,12 +42,45 @@ def get_connection_string(connection_string: str) -> str: ) -def using_managed_identity(connection_name: str) -> bool: +def using_system_managed_identity(connection_name: str) -> bool: """ - To determine if managed identity is being used, we check if the provided - connection string has either of the two suffixes: + To determine if system-assigned managed identity is being used, we check if + the provided connection string has either of the two suffixes: __serviceUri or __blobServiceUri. """ return (os.getenv(connection_name + "__serviceUri") is not None) or ( os.getenv(connection_name + "__blobServiceUri") is not None ) + + +def using_user_managed_identity(connection_name: str) -> bool: + """ + To determine if user-assigned managed identity is being used, we check if + the provided connection string has the following suffixes: + __serviceUri and __credential, AND either __managedIdentityResourceId + or __clientID. + + We are not verifying that the customer only has __managedIdentityResourceId + or __clientID. That check is handled by the (???). + """ + return ((os.getenv(connection_name + "__serviceUri") is not None) + and (os.getenv(connection_name + "__credential") is not None) + and ((os.getenv(connection_name + "__managedIdentityResourceId") + is not None) + or (os.getenv(connection_name + "__clientID") is not None + ))) + + +def get_blob_service_client(system_managed_identity: bool, + user_managed_identity: bool, + connection: str): + connection_string = get_connection_string(connection) + if user_managed_identity: + return BlobServiceClient(account_url=connection_string, + credential=ManagedIdentityCredential( + client_id=os.getenv(connection + "__clientID"))) + elif system_managed_identity: + return BlobServiceClient(account_url=connection_string, + credential=DefaultAzureCredential()) + else: + return BlobServiceClient.from_connection_string(connection_string) diff --git a/azurefunctions-extensions-bindings-blob/tests/test_blobclient.py b/azurefunctions-extensions-bindings-blob/tests/test_blobclient.py index 68b66aa..ec29fbc 100644 --- a/azurefunctions-extensions-bindings-blob/tests/test_blobclient.py +++ b/azurefunctions-extensions-bindings-blob/tests/test_blobclient.py @@ -93,11 +93,15 @@ def test_input_incorrect_type(self): ) def test_input_empty(self): - datum: Datum = Datum(value={}, type="model_binding_data") - with self.assertRaises(ValueError): - BlobClientConverter.decode( + with self.assertRaises(ValueError) as e: + datum: Datum = Datum(value={}, type="model_binding_data") + _: BlobClient = BlobClientConverter.decode( data=datum, trigger_metadata=None, pytype=BlobClient ) + self.assertEqual( + e.exception.args[0], + "ValueError: Unable to create BlobClient SDK type.", + ) def test_input_populated(self): content = {