Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_auth.py: 50%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""Azure authentication Service: provides Azure SDK credentials, ARM access tokens, and HTTP auth headers."""

6 

7import logging 

8from base64 import b64decode 

9from collections.abc import Callable 

10from datetime import datetime 

11from typing import Any 

12 

13from azure.core.credentials import TokenCredential 

14from azure.identity import CertificateCredential, DefaultAzureCredential 

15from azure.keyvault.secrets import SecretClient 

16from pytz import UTC 

17 

18from mlos_bench.services.base_service import Service 

19from mlos_bench.services.types.authenticator_type import SupportsAuth 

20from mlos_bench.util import check_required_params 

21 

# Module-level logger for the Azure authentication service.
_LOG = logging.getLogger(__name__)

23 

24 

class AzureAuthService(Service, SupportsAuth[TokenCredential]):
    """
    Helper methods to get access to Azure services.

    Provides an Azure SDK credential, Azure Resource Manager access tokens
    obtained from that credential, and HTTP ``Authorization`` headers built
    from those tokens.

    Without ``spClientId`` in the config, a ``DefaultAzureCredential`` is used.
    With ``spClientId``, the default credential is only used to fetch the
    Service Principal's certificate from Azure Key Vault, and a
    ``CertificateCredential`` for that Service Principal is used instead.
    """

    _REQ_INTERVAL = 300  # = 5 min (seconds)

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new instance of Azure authentication services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Keys read by this service:
            ``tokenRequestInterval`` (seconds before token expiration at which a
            new token is requested; defaults to 300), and, for Service Principal
            auth, ``spClientId``, ``keyVaultName``, ``certName``, ``tenant``
            (all four are required once ``spClientId`` is present).
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : dict[str, Callable] | list[Callable] | None
            New methods to register with the service.
        """
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(
                methods,
                [
                    self.get_access_token,
                    self.get_auth_headers,
                    self.get_credential,
                ],
            ),
        )

        # This parameter can come from command line as strings, so conversion is needed.
        self._req_interval = float(self.config.get("tokenRequestInterval", self._REQ_INTERVAL))

        # Placeholder token. Setting the expiration to "now" guarantees that the
        # first get_access_token() call requests a real token.
        self._access_token = "RENEW *NOW*"
        self._token_expiration_ts = datetime.now(UTC)  # Typically, some future timestamp.
        # Created lazily by get_credential() and cached afterwards.
        self._cred: TokenCredential | None = None

        # Verify info required for SP auth early
        if "spClientId" in self.config:
            check_required_params(
                self.config,
                {
                    "spClientId",
                    "keyVaultName",
                    "certName",
                    "tenant",
                },
            )

    def get_credential(self) -> TokenCredential:
        """
        Return the Azure SDK credential object, creating and caching it on first use.

        Returns
        -------
        TokenCredential
            A ``DefaultAzureCredential`` when ``spClientId`` is not configured,
            otherwise a ``CertificateCredential`` for the configured Service
            Principal.

        Raises
        ------
        ValueError
            If the certificate secret fetched from Key Vault has no value.
        """
        # Perform this initialization outside of __init__ so that environment loading tests
        # don't need to specifically mock keyvault interactions out
        if self._cred is not None:
            return self._cred

        default_cred = DefaultAzureCredential()
        if "spClientId" not in self.config:
            self._cred = default_cred
            return self._cred

        sp_client_id = self.config["spClientId"]
        keyvault_name = self.config["keyVaultName"]
        cert_name = self.config["certName"]
        tenant_id = self.config["tenant"]
        _LOG.debug("Log in with Azure Service Principal %s", sp_client_id)

        # Get a client for fetching cert info, authenticated with the default credential.
        keyvault_secrets_client = SecretClient(
            vault_url=f"https://{keyvault_name}.vault.azure.net",
            credential=default_cred,
        )
        try:
            # The certificate private key data is stored as hidden "Secret" (not Key strangely)
            # in PKCS12 format, but we need to decode it.
            secret = keyvault_secrets_client.get_secret(cert_name)
        finally:
            # The client is only needed for this one lookup; release its transport.
            keyvault_secrets_client.close()

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if secret.value is None:
            raise ValueError(
                f"Certificate secret {cert_name!r} in Key Vault {keyvault_name!r} has no value."
            )
        cert_bytes = b64decode(secret.value)

        # Reauthenticate as the service principal. The credential is cached only
        # after every step above has succeeded, so a failure (e.g., Key Vault access
        # denied) does not leave the default credential cached in place of the
        # Service Principal one; the next call retries the SP login instead.
        self._cred = CertificateCredential(  # pylint: disable=redefined-variable-type
            tenant_id=tenant_id,
            client_id=sp_client_id,
            certificate_data=cert_bytes,
        )
        return self._cred

    def get_access_token(self) -> str:
        """
        Get an Azure Resource Manager access token, renewing it if needed.

        A new token for the ``https://management.azure.com/.default`` scope is
        requested via :py:meth:`get_credential` when fewer than
        ``tokenRequestInterval`` seconds remain before the cached token expires.
        Otherwise the cached token is returned.
        """
        ts_diff = (self._token_expiration_ts - datetime.now(UTC)).total_seconds()
        _LOG.debug("Time to renew the token: %.2f sec.", ts_diff)
        if ts_diff < self._req_interval:
            _LOG.debug("Request new accessToken")
            res = self.get_credential().get_token("https://management.azure.com/.default")
            # `expires_on` is a POSIX timestamp; keep it tz-aware to compare with now(UTC).
            self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC)
            self._access_token = res.token
            _LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts)
        return self._access_token

    def get_auth_headers(self) -> dict[str, str]:
        """Get the authorization part of HTTP headers for REST API calls (a Bearer token)."""
        return {"Authorization": "Bearer " + self.get_access_token()}