Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add credential support on Azure platform #3678

Merged
merged 10 commits into from
Feb 27, 2025
19 changes: 12 additions & 7 deletions lisa/sut_orchestrator/azure/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2300,7 +2300,7 @@ def find_storage_account(

def get_token(platform: "AzurePlatform") -> str:
token = platform.credential.get_token(platform.cloud.endpoints.resource_manager)
return token.token
return str(token.token)


def _generate_sas_token_for_vhd(
Expand Down Expand Up @@ -2782,12 +2782,17 @@ def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
return AccessToken(self._token, self._expires_on)

def _get_exp(self) -> Any:
# The second part of the JWT is the payload
payload = self._token.split(".")[1]
# Add padding to ensure Base64 decoding works properly
padded_payload = payload + "=" * (4 - len(payload) % 4)
# Decode the Base64 URL-safe encoded payload
decoded_payload = base64.urlsafe_b64decode(padded_payload)
try:
# The second part of the JWT is the payload
payload = self._token.split(".")[1]
# Add padding to ensure Base64 decoding works properly
padded_payload = payload + "=" * (4 - len(payload) % 4)
# Decode the Base64 URL-safe encoded payload
decoded_payload = base64.urlsafe_b64decode(padded_payload)
except Exception as e:
raise LisaException(
f"Failed to decode JWT payload, maybe invalid token: {e}"
)
# Convert the payload into a dictionary and get the expiration time
# 'exp' is the UNIX timestamp for expiration
return json.loads(decoded_payload).get("exp")
Expand Down
194 changes: 147 additions & 47 deletions lisa/sut_orchestrator/azure/credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,49 +11,64 @@
ManagedIdentityCredential,
)
from dataclasses_json import dataclass_json
from msrestazure.azure_cloud import AZURE_PUBLIC_CLOUD, Cloud # type: ignore

from lisa import schema
from lisa.util import constants, subclasses
from lisa import schema, secret
from lisa.util import subclasses
from lisa.util.logger import Logger

from .common import get_static_access_token


class AzureCredentialType(str, Enum):
DefaultAzureCredential = constants.DEFAULT_AZURE_CREDENTIAL
CertificateCredential = constants.CERTIFICATE_CREDENTIAL
ClientAssertionCredential = constants.CLIENT_ASSERTION_CREDENTIAL
ClientSecretCredential = constants.CLIENT_SECRET_CREDENTIAL
DefaultAzureCredential = "default"
CertificateCredential = "certificate"
ClientAssertionCredential = "assertion"
ClientSecretCredential = "secret"
TokenCredential = "token"


@dataclass_json()
@dataclass
class AzureCredentialSchema(schema.TypedSchema, schema.ExtendableSchemaMixin):
type: str = AzureCredentialType.DefaultAzureCredential
tenant_id: str = ""
client_id: str = ""
type: str = AzureCredentialType.DefaultAzureCredential


@dataclass_json()
@dataclass
class CertCredentialSchema(AzureCredentialSchema):
cert_path: str = ""
client_send_cert_chain = "false"
type: str = AzureCredentialType.CertificateCredential


@dataclass_json()
@dataclass
class ClientAssertionCredentialSchema(AzureCredentialSchema):
msi_client_id: str = ""
enterprise_app_client_id: str = ""
type: str = AzureCredentialType.ClientAssertionCredential


@dataclass_json()
@dataclass
class ClientSecretCredentialSchema(AzureCredentialSchema):
# for ClientSecretCredential, will be deprecated due to Security WAVE
client_secret: str = ""
type: str = AzureCredentialType.ClientSecretCredential

def __post_init__(self) -> None:
assert self.client_secret, "client_secret shouldn't be empty"
secret.add_secret(self.client_secret)


@dataclass_json()
@dataclass
class TokenCredentialSchema(AzureCredentialSchema):
token: str = ""

def __post_init__(self) -> None:
assert self.token, "token shouldn't be empty"
secret.add_secret(self.token)


class AzureCredential(subclasses.BaseClassWithRunbookMixin):
Expand All @@ -63,66 +78,115 @@ class AzureCredential(subclasses.BaseClassWithRunbookMixin):

@classmethod
def type_name(cls) -> str:
return constants.DEFAULT_AZURE_CREDENTIAL
raise NotImplementedError()

@classmethod
def type_schema(cls) -> Type[schema.TypedSchema]:
return AzureCredentialSchema

def __init__(self, runbook: AzureCredentialSchema) -> None:
raise NotImplementedError()

def __init__(
self,
runbook: AzureCredentialSchema,
logger: Logger,
cloud: Cloud = AZURE_PUBLIC_CLOUD,
) -> None:
super().__init__(runbook=runbook)
self._log = logger

if runbook.type:
self._credential_type = runbook.type
else:
self._credential_type = AzureCredentialType.DefaultAzureCredential

self._log.debug(f"Credential type: {self._credential_type}")
self._cloud = cloud

# parameters overwrite seq: env var <- runbook <- cmd
self._credential_type: str = AzureCredentialType.DefaultAzureCredential
self._client_id = os.environ.get("AZURE_CLIENT_ID", "")
self._tenant_id = os.environ.get("AZURE_TENANT_ID", "")
self._client_id = os.environ.get("AZURE_CLIENT_ID", "")

assert runbook, "azure_credential shouldn't be empty"
self._azure_credential = runbook
if runbook.type:
self._credential_type = runbook.type
if runbook.client_id:
self._client_id = runbook.client_id
if runbook.tenant_id:
self._tenant_id = runbook.tenant_id
self._log.debug(f"Use defined tenant id: {self._tenant_id}")
os.environ["AZURE_TENANT_ID"] = self._tenant_id
if runbook.client_id:
self._client_id = runbook.client_id
self._log.debug(f"Use defined client id: {self._client_id}")
os.environ["AZURE_CLIENT_ID"] = self._client_id

def __hash__(self) -> int:
return hash(self._get_key())

def get_credential(self) -> Any:
raise NotImplementedError()

def _get_key(self) -> str:
return f"{self._credential_type}_{self._client_id}_{self._tenant_id}"


def get_credential(self, log: Logger) -> Any:
class AzureDefaultCredential(AzureCredential):
"""
Class to create DefaultAzureCredential based on runbook Schema. Because the
subclass factory doesn't instance the base class, so create a subclass to be
instanced.
"""

@classmethod
def type_name(cls) -> str:
return AzureCredentialType.DefaultAzureCredential

@classmethod
def type_schema(cls) -> Type[schema.TypedSchema]:
return AzureCredentialSchema

def __hash__(self) -> int:
return hash(self._get_key())

def get_credential(self) -> Any:
"""
return AzureCredential with related schema
"""
log.info("Authenticating using DefaultAzureCredential")
return DefaultAzureCredential()
return DefaultAzureCredential(cloud=self._cloud)

def _get_key(self) -> str:
return f"{self._credential_type}_{self._client_id}_{self._tenant_id}"


class AzureCertificateCredential(AzureCredential):
"""
Class to create azure credential based on runbook AzureCredentialSchema.
Methods:
get_credential(self, log: Logger) -> Any:
return the credential based on runbook AzureCredentialSchema define.
"""

@classmethod
def type_name(cls) -> str:
return constants.CERTIFICATE_CREDENTIAL
return AzureCredentialType.CertificateCredential

@classmethod
def type_schema(cls) -> Type[schema.TypedSchema]:
return CertCredentialSchema

def __init__(self, runbook: CertCredentialSchema) -> None:
super().__init__(runbook)
def __init__(
self,
runbook: CertCredentialSchema,
logger: Logger,
cloud: Cloud = AZURE_PUBLIC_CLOUD,
) -> None:
super().__init__(runbook, cloud=cloud, logger=logger)
self._cert_path = os.environ.get("AZURE_CLIENT_CERTIFICATE_PATH", "")
self._client_send_cert_chain = "false"

runbook = cast(CertCredentialSchema, self.runbook)
self._credential_type = AzureCredentialType.CertificateCredential
if runbook.cert_path:
self._cert_path = runbook.cert_path
self._log.debug(f"Use defined cert path: {self._cert_path}")
os.environ["AZURE_CLIENT_CERTIFICATE_PATH"] = self._cert_path
if runbook.client_send_cert_chain:
self._client_send_cert_chain = runbook.client_send_cert_chain

def get_credential(self, log: Logger) -> Any:
log.info(f"Authenticating using cert path: {self._cert_path}")
def get_credential(self) -> Any:
self._log.info(f"Authenticating using cert path: {self._cert_path}")

assert self._tenant_id, "tenant id shouldn't be none for CertificateCredential"
assert self._client_id, "client id shouldn't be none for CertificateCredential"
Expand All @@ -143,15 +207,20 @@ class AzureClientAssertionCredential(AzureCredential):

@classmethod
def type_name(cls) -> str:
return constants.CLIENT_ASSERTION_CREDENTIAL
return AzureCredentialType.ClientAssertionCredential

@classmethod
def type_schema(cls) -> Type[schema.TypedSchema]:
return ClientAssertionCredentialSchema

def __init__(self, runbook: ClientAssertionCredentialSchema) -> None:
def __init__(
self,
runbook: ClientAssertionCredentialSchema,
logger: Logger,
cloud: Cloud = AZURE_PUBLIC_CLOUD,
) -> None:
if runbook:
super().__init__(runbook)
super().__init__(runbook, cloud=cloud, logger=logger)
self._msi_client_id = ""
self._enterprise_app_client_id = ""
self._credential_type = AzureCredentialType.ClientAssertionCredential
Expand Down Expand Up @@ -185,8 +254,8 @@ def get_cross_tenant_credential(
)
return credential

def get_credential(self, log: Logger) -> Any:
log.info("Authenticating using ClientAssertionCredential")
def get_credential(self) -> Any:
self._log.info("Authenticating using ClientAssertionCredential")
return self.get_cross_tenant_credential(
self._msi_client_id, self._enterprise_app_client_id, self._tenant_id
)
Expand All @@ -196,30 +265,35 @@ class AzureClientSecretCredential(AzureCredential):
"""
Class to create ClientSecretCredential based on runbook Schema
Methods:
get_credential(self, log: Logger) -> Any:
get_credential(self) -> Any:
return the credential based on runbook Schema define.
"""

@classmethod
def type_name(cls) -> str:
return constants.CLIENT_SECRET_CREDENTIAL
return AzureCredentialType.ClientSecretCredential

@classmethod
def type_schema(cls) -> Type[schema.TypedSchema]:
return ClientSecretCredentialSchema

def __init__(self, runbook: ClientSecretCredentialSchema) -> None:
super().__init__(runbook)
def __init__(
self,
runbook: ClientSecretCredentialSchema,
logger: Logger,
cloud: Cloud = AZURE_PUBLIC_CLOUD,
) -> None:
super().__init__(runbook, cloud=cloud, logger=logger)
self._credential_type = AzureCredentialType.ClientSecretCredential
self._client_secret = os.environ.get("AZURE_CLIENT_SECRET", "")

runbook = cast(ClientSecretCredentialSchema, self.runbook)
if runbook.client_id:
self._client_id = runbook.client_id
if runbook.tenant_id:
self._tenant_id = runbook.tenant_id
if runbook.client_secret:
self._client_secret = runbook.client_secret
self._log.debug(
f"Use defined client secret: ({len(self._client_secret)} bytes)"
)
os.environ["AZURE_CLIENT_SECRET"] = self._client_secret

def get_client_secret_credential(
self, tenant_id: str, client_id: str, client_secret: str
Expand All @@ -237,8 +311,34 @@ def get_client_secret_credential(
tenant_id=tenant_id, client_id=client_id, client_secret=client_secret
)

def get_credential(self, log: Logger) -> Any:
log.info("Authenticating using ClientSecretCredential")
def get_credential(self) -> Any:
self._log.info("Authenticating using ClientSecretCredential")
return self.get_client_secret_credential(
self._tenant_id, self._client_id, self._client_secret
)


class AzureTokenCredential(AzureCredential):
"""
Class to create azure credential based on preappled tokens
"""

@classmethod
def type_name(cls) -> str:
return AzureCredentialType.TokenCredential

@classmethod
def type_schema(cls) -> Type[schema.TypedSchema]:
return TokenCredentialSchema

def __init__(
self,
runbook: TokenCredentialSchema,
logger: Logger,
cloud: Cloud = AZURE_PUBLIC_CLOUD,
) -> None:
super().__init__(runbook, cloud=cloud, logger=logger)
self._token = runbook.token

def get_credential(self) -> Any:
return get_static_access_token(self._token)
2 changes: 1 addition & 1 deletion lisa/sut_orchestrator/azure/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def _get_access_token(self) -> str:
"https://management.core.windows.net/.default"
).token

return access_token
return str(access_token)

def _get_console_log(self, saved_path: Optional[Path]) -> bytes:
platform: AzurePlatform = self._platform # type: ignore
Expand Down
Loading
Loading