Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_auth.py: 49%
51 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""A collection of Service functions for managing VMs on Azure."""
7import logging
8from base64 import b64decode
9from datetime import datetime
10from typing import Any, Callable, Dict, List, Optional, Union
12from azure.core.credentials import TokenCredential
13from azure.identity import CertificateCredential, DefaultAzureCredential
14from azure.keyvault.secrets import SecretClient
15from pytz import UTC
17from mlos_bench.services.base_service import Service
18from mlos_bench.services.types.authenticator_type import SupportsAuth
19from mlos_bench.util import check_required_params
21_LOG = logging.getLogger(__name__)
class AzureAuthService(Service, SupportsAuth[TokenCredential]):
    """
    Helper methods to get access to Azure services.

    The service authenticates with ``DefaultAzureCredential`` unless the config
    contains ``spClientId``; in that case it logs in as that Service Principal
    using a certificate fetched from Azure Key Vault. Access tokens are cached
    and renewed once they are within ``tokenRequestInterval`` seconds of expiry.
    """

    _REQ_INTERVAL = 300  # = 5 min

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new instance of Azure authentication services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Keys read by this service: ``tokenRequestInterval``
            (optional, seconds; defaults to ``_REQ_INTERVAL``) and, for Service
            Principal authentication, ``spClientId``, ``keyVaultName``,
            ``certName``, and ``tenant``.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(
                methods,
                [
                    self.get_access_token,
                    self.get_auth_headers,
                    self.get_credential,
                ],
            ),
        )

        # This parameter can come from command line as strings, so conversion is needed.
        self._req_interval = float(self.config.get("tokenRequestInterval", self._REQ_INTERVAL))

        # Placeholder token; the expiration timestamp is set to "now" so that
        # the first call to `get_access_token()` always requests a real token.
        self._access_token = "RENEW *NOW*"
        self._token_expiration_ts = datetime.now(UTC)  # Typically, some future timestamp.
        # Created lazily by `get_credential()`.
        self._cred: Optional[TokenCredential] = None

        # Verify info required for SP auth early
        if "spClientId" in self.config:
            check_required_params(
                self.config,
                {
                    "spClientId",
                    "keyVaultName",
                    "certName",
                    "tenant",
                },
            )

    def get_credential(self) -> TokenCredential:
        """
        Return the Azure SDK credential object.

        The credential is created on first use and cached afterwards. It is
        cached only after it has been fully set up, so a failed Service
        Principal login is retried on the next call instead of silently falling
        back to the default credential.

        Returns
        -------
        credential : TokenCredential
            ``DefaultAzureCredential`` when ``spClientId`` is not configured,
            otherwise a ``CertificateCredential`` for the Service Principal.

        Raises
        ------
        ValueError
            If the Key Vault secret that should hold the certificate is empty.
        """
        # Perform this initialization outside of __init__ so that environment loading tests
        # don't need to specifically mock keyvault interactions out
        if self._cred is not None:
            return self._cred

        # Keep the bootstrap credential in a local variable: if anything below
        # raises, `self._cred` must stay unset so that the next call retries.
        default_cred = DefaultAzureCredential()
        if "spClientId" not in self.config:
            self._cred = default_cred
            return self._cred

        sp_client_id = self.config["spClientId"]
        keyvault_name = self.config["keyVaultName"]
        cert_name = self.config["certName"]
        tenant_id = self.config["tenant"]
        _LOG.debug("Log in with Azure Service Principal %s", sp_client_id)

        # Get a client for fetching cert info
        keyvault_secrets_client = SecretClient(
            vault_url=f"https://{keyvault_name}.vault.azure.net",
            credential=default_cred,
        )

        # The certificate private key data is stored as hidden "Secret" (not Key strangely)
        # in PKCS12 format, but we need to decode it.
        secret = keyvault_secrets_client.get_secret(cert_name)
        if secret.value is None:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            raise ValueError(f"Key Vault secret has no value: {cert_name}")
        cert_bytes = b64decode(secret.value)

        # Reauthenticate as the service principal.
        self._cred = CertificateCredential(  # pylint: disable=redefined-variable-type
            tenant_id=tenant_id,
            client_id=sp_client_id,
            certificate_data=cert_bytes,
        )
        return self._cred

    def get_access_token(self) -> str:
        """
        Get the cached access token, requesting a new one if it is about to expire.

        A new token is requested from the credential returned by
        `get_credential()` when the cached token expires in less than
        ``tokenRequestInterval`` seconds.

        Returns
        -------
        access_token : str
            The access token string.
        """
        ts_diff = (self._token_expiration_ts - datetime.now(UTC)).total_seconds()
        _LOG.debug("Time to renew the token: %.2f sec.", ts_diff)
        if ts_diff < self._req_interval:
            _LOG.debug("Request new accessToken")
            res = self.get_credential().get_token("https://management.azure.com/.default")
            self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC)
            self._access_token = res.token
            _LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts)
        return self._access_token

    def get_auth_headers(self) -> dict:
        """
        Get the authorization part of HTTP headers for REST API calls.

        Returns
        -------
        headers : dict
            A single ``Authorization`` header carrying the current access token
            as a bearer token.
        """
        return {"Authorization": "Bearer " + self.get_access_token()}