Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 354b16b

Browse files
committed Mar 5, 2024 ·
pluggable azure credentials provider
1 parent 2574095 commit 354b16b

File tree

7 files changed

+75
-43
lines changed

7 files changed

+75
-43
lines changed
 

‎metaflow/extension_support/plugins.py

+1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def resolve_plugins(category):
179179
"metadata_provider": lambda x: x.TYPE,
180180
"datastore": lambda x: x.TYPE,
181181
"secrets_provider": lambda x: x.TYPE,
182+
"azure_client_provider": lambda x: x.name,
182183
"sidecar": None,
183184
"logging_sidecar": None,
184185
"monitor_sidecar": None,

‎metaflow/plugins/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@
122122
),
123123
]
124124

125+
AZURE_CLIENT_PROVIDERS_DESC = [
126+
("azure-default", ".azure.azure_credential.AzureDefaultClientProvider")
127+
]
128+
129+
125130
process_plugins(globals())
126131

127132

@@ -143,6 +148,7 @@ def get_plugin_cli():
143148

144149
AWS_CLIENT_PROVIDERS = resolve_plugins("aws_client_provider")
145150
SECRETS_PROVIDERS = resolve_plugins("secrets_provider")
151+
AZURE_CLIENT_PROVIDERS = resolve_plugins("azure_client_provider")
146152

147153
from .cards.card_modules import MF_EXTERNAL_CARDS
148154

‎metaflow/plugins/azure/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .azure_credential import (
2+
create_cacheable_azure_credential as create_azure_credential,
3+
)

‎metaflow/plugins/azure/azure_credential.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
class AzureDefaultClientProvider(object):
    """Azure client provider backed by azure.identity.DefaultAzureCredential."""

    # Identifier compared against DEFAULT_AZURE_CLIENT_PROVIDER when the
    # provider to use is looked up.
    name = "azure-default"

    @staticmethod
    def create_cacheable_azure_credential(*args, **kwargs):
        """Return a DefaultAzureCredential that can be used as a dictionary key.

        azure.identity.DefaultAzureCredential is not readily cacheable in a
        dictionary because it does not have content based hash and equality
        implementations.

        We implement a subclass CacheableDefaultAzureCredential to add them.

        We need this because credentials will be part of the cache key in
        _ClientCache.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to DefaultAzureCredential. Every value must be
            hashable, because the arguments form the credential's identity.

        Returns
        -------
        CacheableDefaultAzureCredential
            A credential that hashes and compares by its construction
            arguments.

        Raises
        ------
        TypeError
            If any positional or keyword argument is unhashable.
        """
        from azure.identity import DefaultAzureCredential

        class CacheableDefaultAzureCredential(DefaultAzureCredential):
            def __init__(self, *args, **kwargs):
                super(CacheableDefaultAzureCredential, self).__init__(*args, **kwargs)
                # Just hashing all the kwargs works because they are all
                # individually hashable as of 7/15/2022.
                #
                # What if Azure adds unhashable things to kwargs?
                # - We will have CI to catch this (it will always install the
                #   latest Azure SDKs)
                # - In Metaflow usage today we never specify any kwargs anyway
                #   (the module-level create_cacheable_azure_credential() calls
                #   this factory without arguments).
                self._metaflow_cache_key = (args, tuple(sorted(kwargs.items())))
                self._hash_code = hash(self._metaflow_cache_key)

            def __hash__(self):
                return self._hash_code

            def __eq__(self, other):
                # Compare the construction arguments themselves rather than
                # their hashes: comparing hashes makes any object with a
                # colliding hash "equal" and raises for unhashable objects.
                #
                # isinstance() is deliberately not used: this class is
                # re-created on every factory call, so credentials from two
                # separate calls have distinct types yet must still be equal.
                if self is other:
                    return True
                try:
                    other_key = other._metaflow_cache_key
                except AttributeError:
                    return NotImplemented
                return self._metaflow_cache_key == other_key

        return CacheableDefaultAzureCredential(*args, **kwargs)
34+
35+
36+
cached_provider_class = None
37+
38+
39+
def create_cacheable_azure_credential():
    """Create a credential using the configured Azure client provider.

    On the first call the provider whose ``name`` equals
    DEFAULT_AZURE_CLIENT_PROVIDER is looked up in AZURE_CLIENT_PROVIDERS and
    remembered in the module-level ``cached_provider_class``; later calls
    reuse it. The credential itself is produced by the provider's
    ``create_cacheable_azure_credential()``.

    Raises
    ------
    ValueError
        If no registered provider matches DEFAULT_AZURE_CLIENT_PROVIDER.
    """
    global cached_provider_class

    # Fast path: the provider was already resolved by an earlier call.
    if cached_provider_class is not None:
        return cached_provider_class.create_cacheable_azure_credential()

    # Imported here rather than at module import time.
    from metaflow.metaflow_config import DEFAULT_AZURE_CLIENT_PROVIDER
    from metaflow.plugins import AZURE_CLIENT_PROVIDERS

    # First provider with a matching name wins; the search stops at that hit.
    selected = next(
        (
            candidate
            for candidate in AZURE_CLIENT_PROVIDERS
            if candidate.name == DEFAULT_AZURE_CLIENT_PROVIDER
        ),
        None,
    )
    if selected is None:
        raise ValueError(
            "Cannot find Azure Client provider %s" % DEFAULT_AZURE_CLIENT_PROVIDER
        )

    cached_provider_class = selected
    return cached_provider_class.create_cacheable_azure_credential()

‎metaflow/plugins/azure/azure_utils.py

+2-35
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
MetaflowAzurePackageError,
88
)
99
from metaflow.exception import MetaflowInternalError, MetaflowException
10+
from metaflow.plugins.azure.azure_credential import create_cacheable_azure_credential
1011

1112

1213
def _check_and_init_azure_deps():
@@ -138,38 +139,6 @@ def _inner_func(*args, **kwargs):
138139
return _inner_func
139140

140141

141-
@check_azure_deps
142-
def create_cacheable_default_azure_credentials(*args, **kwargs):
143-
"""azure.identity.DefaultAzureCredential is not readily cacheable in a dictionary
144-
because it does not have a content based hash and equality implementations.
145-
146-
We implement a subclass CacheableDefaultAzureCredential to add them.
147-
148-
We need this because credentials will be part of the cache key in _ClientCache.
149-
"""
150-
from azure.identity import DefaultAzureCredential
151-
152-
class CacheableDefaultAzureCredential(DefaultAzureCredential):
153-
def __init__(self, *args, **kwargs):
154-
super(CacheableDefaultAzureCredential, self).__init__(*args, **kwargs)
155-
# Just hashing all the kwargs works because they are all individually
156-
# hashable as of 7/15/2022.
157-
#
158-
# What if Azure adds unhashable things to kwargs?
159-
# - We will have CI to catch this (it will always install the latest Azure SDKs)
160-
# - In Metaflow usage today we never specify any kwargs anyway. (see last line
161-
# of the outer function.
162-
self._hash_code = hash((args, tuple(sorted(kwargs.items()))))
163-
164-
def __hash__(self):
165-
return self._hash_code
166-
167-
def __eq__(self, other):
168-
return hash(self) == hash(other)
169-
170-
return CacheableDefaultAzureCredential(*args, **kwargs)
171-
172-
173142
@check_azure_deps
174143
def create_static_token_credential(token_):
175144
from azure.core.credentials import TokenCredential
@@ -200,9 +169,7 @@ def __init__(self, token):
200169
def get_token(self, *_scopes, **_kwargs):
201170

202171
if (self._cached_token.expires_on - time.time()) < 300:
203-
from azure.identity import DefaultAzureCredential
204-
205-
self._credential = DefaultAzureCredential()
172+
self._credential = create_cacheable_azure_credential()
206173
if self._credential:
207174
return self._credential.get_token(*_scopes, **_kwargs)
208175
return self._cached_token

‎metaflow/plugins/azure/blob_service_client_factory.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from metaflow.exception import MetaflowException
22
from metaflow.metaflow_config import AZURE_STORAGE_BLOB_SERVICE_ENDPOINT
33
from metaflow.plugins.azure.azure_utils import (
4-
create_cacheable_default_azure_credentials,
54
check_azure_deps,
65
)
6+
from metaflow.plugins.azure.azure_credential import (
7+
create_cacheable_azure_credential,
8+
)
79

810
import os
911
import threading
@@ -125,7 +127,7 @@ def get_azure_blob_service_client(
125127
blob_service_endpoint = AZURE_STORAGE_BLOB_SERVICE_ENDPOINT
126128

127129
if not credential:
128-
credential = create_cacheable_default_azure_credentials()
130+
credential = create_cacheable_azure_credential()
129131
credential_is_cacheable = True
130132

131133
if not credential_is_cacheable:

‎metaflow/plugins/datastores/azure_storage.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
handle_executor_exceptions,
3333
)
3434

35+
from metaflow.plugins.azure.azure_credential import create_cacheable_azure_credential
36+
3537
AZURE_STORAGE_DOWNLOAD_MAX_CONCURRENCY = 4
3638
AZURE_STORAGE_UPLOAD_MAX_CONCURRENCY = 16
3739

@@ -266,12 +268,10 @@ def _get_default_token(self):
266268
if not self._default_scope_token or (
267269
self._default_scope_token.expires_on - time.time() < 300
268270
):
269-
from azure.identity import DefaultAzureCredential
270-
271-
with DefaultAzureCredential() as credential:
272-
self._default_scope_token = credential.get_token(
273-
AZURE_STORAGE_DEFAULT_SCOPE
274-
)
271+
credential = create_cacheable_azure_credential()
272+
self._default_scope_token = credential.get_token(
273+
AZURE_STORAGE_DEFAULT_SCOPE
274+
)
275275
return self._default_scope_token
276276

277277
@property

0 commit comments

Comments
 (0)
Please sign in to comment.