Skip to content

build: Add User-Based MI #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import json

from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from azurefunctions.extensions.base import Datum, SdkType
from .utils import get_connection_string, using_managed_identity
from .utils import (using_system_managed_identity,
using_user_managed_identity,
get_blob_service_client)


class BlobClient(SdkType):
Expand All @@ -25,9 +25,12 @@ def __init__(self, *, data: Datum) -> None:
self._source = data.source
self._content_type = data.content_type
content_json = json.loads(data.content)
self._connection = get_connection_string(content_json.get("Connection"))
self._using_managed_identity = using_managed_identity(
content_json.get("Connection")
self._connection = content_json.get("Connection")
self._system_managed_identity = using_system_managed_identity(
self._connection
)
self._user_managed_identity = using_user_managed_identity(
self._connection
)
self._containerName = content_json.get("ContainerName")
self._blobName = content_json.get("BlobName")
Expand All @@ -38,19 +41,17 @@ def get_sdk_type(self):
through a BlobServiceClient. There are two ways to create a
BlobServiceClient:
1. Through the constructor: this is the only option when using Managed Identity
1a. If system-based MI, the credential is DefaultAzureCredential
1b. If user-based MI, the credential is ManagedIdentityCredential
2. Through from_connection_string: this is the only option when
not using Managed Identity

We track if Managed Identity is being used through a flag.
"""
if self._data:
blob_service_client = (
BlobServiceClient(
account_url=self._connection, credential=DefaultAzureCredential()
)
if self._using_managed_identity
else BlobServiceClient.from_connection_string(self._connection)
)
blob_service_client = get_blob_service_client(self._system_managed_identity,
self._user_managed_identity,
self._connection)
return blob_service_client.get_blob_client(
container=self._containerName,
blob=self._blobName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import json

from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from azurefunctions.extensions.base import Datum, SdkType
from .utils import get_connection_string, using_managed_identity
from .utils import (using_system_managed_identity,
using_user_managed_identity,
get_blob_service_client)


class ContainerClient(SdkType):
Expand All @@ -25,23 +25,22 @@ def __init__(self, *, data: Datum) -> None:
self._source = data.source
self._content_type = data.content_type
content_json = json.loads(data.content)
self._connection = get_connection_string(content_json.get("Connection"))
self._using_managed_identity = using_managed_identity(
content_json.get("Connection")
self._connection = content_json.get("Connection")
self._system_managed_identity = using_system_managed_identity(
self._connection
)
self._user_managed_identity = using_user_managed_identity(
self._connection
)
self._containerName = content_json.get("ContainerName")
self._blobName = content_json.get("BlobName")

# Returns a ContainerClient
def get_sdk_type(self):
if self._data:
blob_service_client = (
BlobServiceClient(
account_url=self._connection, credential=DefaultAzureCredential()
)
if self._using_managed_identity
else BlobServiceClient.from_connection_string(self._connection)
)
blob_service_client = get_blob_service_client(self._system_managed_identity,
self._user_managed_identity,
self._connection)
return blob_service_client.get_container_client(
container=self._containerName
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import json

from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from azurefunctions.extensions.base import Datum, SdkType
from .utils import get_connection_string, using_managed_identity
from .utils import (using_system_managed_identity,
using_user_managed_identity,
get_blob_service_client)


class StorageStreamDownloader(SdkType):
Expand All @@ -25,24 +25,22 @@ def __init__(self, *, data: Datum) -> None:
self._source = data.source
self._content_type = data.content_type
content_json = json.loads(data.content)
self._connection = get_connection_string(content_json.get("Connection"))
self._using_managed_identity = using_managed_identity(
content_json.get("Connection")
self._connection = content_json.get("Connection")
self._system_managed_identity = using_system_managed_identity(
self._connection
)
self._user_managed_identity = using_user_managed_identity(
self._connection
)
self._containerName = content_json.get("ContainerName")
self._blobName = content_json.get("BlobName")

# Returns a StorageStreamDownloader
def get_sdk_type(self):
if self._data:
blob_service_client = (
BlobServiceClient(
account_url=self._connection, credential=DefaultAzureCredential()
)
if self._using_managed_identity
else BlobServiceClient.from_connection_string(self._connection)
)
# download_blob() returns a StorageStreamDownloader object
blob_service_client = get_blob_service_client(self._system_managed_identity,
self._user_managed_identity,
self._connection)
return blob_service_client.get_blob_client(
container=self._containerName,
blob=self._blobName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# Licensed under the MIT License.
import os

from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
from azure.storage.blob import BlobServiceClient


def get_connection_string(connection_string: str) -> str:
"""
Expand Down Expand Up @@ -39,12 +42,45 @@ def get_connection_string(connection_string: str) -> str:
)


def using_managed_identity(connection_name: str) -> bool:
def using_system_managed_identity(connection_name: str) -> bool:
"""
To determine if managed identity is being used, we check if the provided
connection string has either of the two suffixes:
To determine if system-assigned managed identity is being used, we check if
the provided connection string has either of the two suffixes:
__serviceUri or __blobServiceUri.
"""
return (os.getenv(connection_name + "__serviceUri") is not None) or (
os.getenv(connection_name + "__blobServiceUri") is not None
)


def using_user_managed_identity(connection_name: str) -> bool:
"""
To determine if user-assigned managed identity is being used, we check if
the provided connection string has the following suffixes:
__serviceUri and __credential, AND either __managedIdentityResourceId
or __clientID.

We are not verifying that the customer only has __managedIdentityResourceId
or __clientID. That check is handled by the (???).
"""
return ((os.getenv(connection_name + "__serviceUri") is not None)
and (os.getenv(connection_name + "__credential") is not None)
and ((os.getenv(connection_name + "__managedIdentityResourceId")
is not None)
or (os.getenv(connection_name + "__clientID") is not None
)))


def get_blob_service_client(system_managed_identity: bool,
user_managed_identity: bool,
connection: str):
connection_string = get_connection_string(connection)
if user_managed_identity:
return BlobServiceClient(account_url=connection_string,
credential=ManagedIdentityCredential(
client_id=os.getenv(connection + "__clientID")))
elif system_managed_identity:
return BlobServiceClient(account_url=connection_string,
credential=DefaultAzureCredential())
else:
return BlobServiceClient.from_connection_string(connection_string)
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,15 @@ def test_input_incorrect_type(self):
)

def test_input_empty(self):
datum: Datum = Datum(value={}, type="model_binding_data")
with self.assertRaises(ValueError):
BlobClientConverter.decode(
with self.assertRaises(ValueError) as e:
datum: Datum = Datum(value={}, type="model_binding_data")
_: BlobClient = BlobClientConverter.decode(
data=datum, trigger_metadata=None, pytype=BlobClient
)
self.assertEqual(
e.exception.args[0],
"ValueError: Unable to create BlobClient SDK type.",
)

def test_input_populated(self):
content = {
Expand Down
Loading