|
7 | 7 | MetaflowAzurePackageError,
|
8 | 8 | )
|
9 | 9 | from metaflow.exception import MetaflowInternalError, MetaflowException
|
| 10 | +from metaflow.plugins.azure.azure_credential import create_cacheable_azure_credential |
10 | 11 |
|
11 | 12 |
|
12 | 13 | def _check_and_init_azure_deps():
|
@@ -138,38 +139,6 @@ def _inner_func(*args, **kwargs):
|
138 | 139 | return _inner_func
|
139 | 140 |
|
140 | 141 |
|
141 |
| -@check_azure_deps |
142 |
| -def create_cacheable_default_azure_credentials(*args, **kwargs): |
143 |
| - """azure.identity.DefaultAzureCredential is not readily cacheable in a dictionary |
144 |
| - because it does not have a content based hash and equality implementations. |
145 |
| -
|
146 |
| - We implement a subclass CacheableDefaultAzureCredential to add them. |
147 |
| -
|
148 |
| - We need this because credentials will be part of the cache key in _ClientCache. |
149 |
| - """ |
150 |
| - from azure.identity import DefaultAzureCredential |
151 |
| - |
152 |
| - class CacheableDefaultAzureCredential(DefaultAzureCredential): |
153 |
| - def __init__(self, *args, **kwargs): |
154 |
| - super(CacheableDefaultAzureCredential, self).__init__(*args, **kwargs) |
155 |
| - # Just hashing all the kwargs works because they are all individually |
156 |
| - # hashable as of 7/15/2022. |
157 |
| - # |
158 |
| - # What if Azure adds unhashable things to kwargs? |
159 |
| - # - We will have CI to catch this (it will always install the latest Azure SDKs) |
160 |
| - # - In Metaflow usage today we never specify any kwargs anyway. (see last line |
161 |
| - # of the outer function. |
162 |
| - self._hash_code = hash((args, tuple(sorted(kwargs.items())))) |
163 |
| - |
164 |
| - def __hash__(self): |
165 |
| - return self._hash_code |
166 |
| - |
167 |
| - def __eq__(self, other): |
168 |
| - return hash(self) == hash(other) |
169 |
| - |
170 |
| - return CacheableDefaultAzureCredential(*args, **kwargs) |
171 |
| - |
172 |
| - |
173 | 142 | @check_azure_deps
|
174 | 143 | def create_static_token_credential(token_):
|
175 | 144 | from azure.core.credentials import TokenCredential
|
@@ -200,9 +169,7 @@ def __init__(self, token):
|
200 | 169 | def get_token(self, *_scopes, **_kwargs):
|
201 | 170 |
|
202 | 171 | if (self._cached_token.expires_on - time.time()) < 300:
|
203 |
| - from azure.identity import DefaultAzureCredential |
204 |
| - |
205 |
| - self._credential = DefaultAzureCredential() |
| 172 | + self._credential = create_cacheable_azure_credential() |
206 | 173 | if self._credential:
|
207 | 174 | return self._credential.get_token(*_scopes, **_kwargs)
|
208 | 175 | return self._cached_token
|
|
0 commit comments