Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_auth.py: 50%
52 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Azure authentication service: SDK credentials, access tokens, and HTTP auth headers."""
7import logging
8from base64 import b64decode
9from collections.abc import Callable
10from datetime import datetime
11from typing import Any
13from azure.core.credentials import TokenCredential
14from azure.identity import CertificateCredential, DefaultAzureCredential
15from azure.keyvault.secrets import SecretClient
16from pytz import UTC
18from mlos_bench.services.base_service import Service
19from mlos_bench.services.types.authenticator_type import SupportsAuth
20from mlos_bench.util import check_required_params
# Logger for this module, named after the module itself.
_LOG = logging.getLogger(name=__name__)
class AzureAuthService(Service, SupportsAuth[TokenCredential]):
    """
    Helper methods to get access to Azure services.

    Exposes three service methods: an Azure SDK credential, a cached Azure
    management access token, and HTTP ``Authorization`` headers built from it.
    """

    # Default renewal margin in seconds: a new token is requested once the cached
    # one is closer than this to its expiration time.
    _REQ_INTERVAL = 300  # = 5 min

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new instance of Azure authentication services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Keys read here: ``tokenRequestInterval`` (optional,
            seconds) and, for service principal auth, ``spClientId`` together
            with the then-required ``keyVaultName``, ``certName`` and ``tenant``.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[dict[str, Callable], list[Callable], None]
            New methods to register with the service.
        """
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(
                methods,
                [
                    self.get_access_token,
                    self.get_auth_headers,
                    self.get_credential,
                ],
            ),
        )

        # This parameter can come from command line as strings, so conversion is needed.
        self._req_interval = float(self.config.get("tokenRequestInterval", self._REQ_INTERVAL))

        # Placeholder token: an expiration of "now" forces a renewal on the first
        # call to get_access_token().
        self._access_token = "RENEW *NOW*"
        self._token_expiration_ts = datetime.now(UTC)  # Typically, some future timestamp.
        # Created lazily (and then cached) by get_credential().
        self._cred: TokenCredential | None = None

        # Verify info required for SP auth early
        if "spClientId" in self.config:
            check_required_params(
                self.config,
                {
                    "spClientId",
                    "keyVaultName",
                    "certName",
                    "tenant",
                },
            )

    def get_credential(self) -> TokenCredential:
        """
        Return the Azure SDK credential object, creating and caching it on first use.

        Without ``spClientId`` in the config this is a ``DefaultAzureCredential``.
        Otherwise the default credential is used only to read the service
        principal's certificate from Key Vault, and a ``CertificateCredential``
        for that service principal is returned instead.

        Returns
        -------
        TokenCredential
            The cached credential.

        Raises
        ------
        ValueError
            If the Key Vault secret holding the certificate has no value.
        """
        # Perform this initialization outside of __init__ so that environment loading tests
        # don't need to specifically mock keyvault interactions out
        if self._cred is not None:
            return self._cred

        default_cred = DefaultAzureCredential()
        if "spClientId" not in self.config:
            self._cred = default_cred
            return self._cred

        sp_client_id = self.config["spClientId"]
        keyvault_name = self.config["keyVaultName"]
        cert_name = self.config["certName"]
        tenant_id = self.config["tenant"]
        _LOG.debug("Log in with Azure Service Principal %s", sp_client_id)

        # Get a client for fetching cert info
        keyvault_secrets_client = SecretClient(
            vault_url=f"https://{keyvault_name}.vault.azure.net",
            credential=default_cred,
        )

        # The certificate private key data is stored as hidden "Secret" (not Key strangely)
        # in PKCS12 format, but we need to decode it.
        secret = keyvault_secrets_client.get_secret(cert_name)
        if secret.value is None:
            # Explicit check (not `assert`) so it is not stripped under `python -O`.
            raise ValueError(
                f"Secret {cert_name!r} in Key Vault {keyvault_name!r} has no value"
            )
        cert_bytes = b64decode(secret.value)

        # Reauthenticate as the service principal. The credential is cached only
        # once it is fully built, so a failed Key Vault lookup above cannot leave
        # the default credential cached in place of the service principal one.
        self._cred = CertificateCredential(  # pylint: disable=redefined-variable-type
            tenant_id=tenant_id,
            client_id=sp_client_id,
            certificate_data=cert_bytes,
        )
        return self._cred

    def get_access_token(self) -> str:
        """
        Return an Azure management access token, renewing it if needed.

        A new token is requested through ``get_credential()`` when the cached one
        expires in less than ``tokenRequestInterval`` seconds (or has already
        expired); otherwise the cached token is returned.
        """
        ts_diff = (self._token_expiration_ts - datetime.now(UTC)).total_seconds()
        _LOG.debug("Time to renew the token: %.2f sec.", ts_diff)
        if ts_diff < self._req_interval:
            _LOG.debug("Request new accessToken")
            res = self.get_credential().get_token("https://management.azure.com/.default")
            # `expires_on` is a POSIX timestamp; keep it timezone-aware in UTC.
            self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC)
            self._access_token = res.token
            _LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts)
        return self._access_token

    def get_auth_headers(self) -> dict:
        """Get the authorization part of HTTP headers for REST API calls (Bearer token)."""
        return {"Authorization": "Bearer " + self.get_access_token()}