Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_auth.py: 49%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""A collection of Service functions for authenticating with Azure."""

import logging
from base64 import b64decode
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union

from azure.core.credentials import TokenCredential
from azure.identity import CertificateCredential, DefaultAzureCredential
from azure.keyvault.secrets import SecretClient
from pytz import UTC

from mlos_bench.services.base_service import Service
from mlos_bench.services.types.authenticator_type import SupportsAuth
from mlos_bench.util import check_required_params

# Module-level logger, named after this module.
_LOG = logging.getLogger(__name__)

22 

23 

class AzureAuthService(Service, SupportsAuth[TokenCredential]):
    """Helper methods to get access to Azure services."""

    _REQ_INTERVAL = 300  # = 5 min (in seconds)

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new instance of Azure authentication services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Optional keys read here: ``tokenRequestInterval``
            (seconds) and, for service principal auth, ``spClientId`` together
            with the required ``keyVaultName``, ``certName`` and ``tenant``.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(
                methods,
                [
                    self.get_access_token,
                    self.get_auth_headers,
                    self.get_credential,
                ],
            ),
        )

        # This parameter can come from command line as strings, so conversion is needed.
        # It is the number of seconds before token expiration at which a new token is requested.
        self._req_interval = float(self.config.get("tokenRequestInterval", self._REQ_INTERVAL))

        # Placeholder token; the expiration timestamp is set to "now" so that the
        # first get_access_token() call always requests a real token.
        self._access_token = "RENEW *NOW*"
        self._token_expiration_ts = datetime.now(UTC)  # Typically, some future timestamp.
        # Lazily created and cached by get_credential().
        self._cred: Optional[TokenCredential] = None

        # Verify info required for SP auth early
        if "spClientId" in self.config:
            check_required_params(
                self.config,
                {
                    "spClientId",
                    "keyVaultName",
                    "certName",
                    "tenant",
                },
            )

    def get_credential(self) -> TokenCredential:
        """
        Return the Azure SDK credential object, creating and caching it on first use.

        Without ``spClientId`` in the config, a ``DefaultAzureCredential`` is
        returned. Otherwise, the default credential is only used to read the
        service principal's certificate from Key Vault, and a
        ``CertificateCredential`` for that service principal is returned.

        Returns
        -------
        TokenCredential
            The (cached) credential object.

        Raises
        ------
        ValueError
            If the Key Vault secret holding the certificate has no value.
        """
        # Perform this initialization outside of __init__ so that environment loading tests
        # don't need to specifically mock keyvault interactions out
        if self._cred is not None:
            return self._cred

        default_cred = DefaultAzureCredential()
        if "spClientId" not in self.config:
            self._cred = default_cred
            return self._cred

        sp_client_id = self.config["spClientId"]
        keyvault_name = self.config["keyVaultName"]
        cert_name = self.config["certName"]
        tenant_id = self.config["tenant"]
        _LOG.debug("Log in with Azure Service Principal %s", sp_client_id)

        # Get a client for fetching cert info; the context manager closes its
        # sockets/transport once the secret has been retrieved.
        with SecretClient(
            vault_url=f"https://{keyvault_name}.vault.azure.net",
            credential=default_cred,
        ) as keyvault_secrets_client:
            # The certificate private key data is stored as hidden "Secret" (not Key strangely)
            # in PKCS12 format, but we need to decode it.
            secret = keyvault_secrets_client.get_secret(cert_name)

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if secret.value is None:
            raise ValueError(
                f"Certificate secret {cert_name!r} in Key Vault {keyvault_name!r} has no value."
            )
        cert_bytes = b64decode(secret.value)

        # Reauthenticate as the service principal. The credential is cached only after
        # it has been fully built, so a failed Key Vault lookup is retried on the next
        # call instead of silently leaving the default credential cached.
        self._cred = CertificateCredential(  # pylint: disable=redefined-variable-type
            tenant_id=tenant_id,
            client_id=sp_client_id,
            certificate_data=cert_bytes,
        )
        return self._cred

    def get_access_token(self) -> str:
        """
        Get an Azure management access token, requesting a new one when needed.

        A new token is requested from the credential returned by
        ``get_credential()`` when the cached one expires in fewer than
        ``tokenRequestInterval`` seconds (or has already expired).

        Returns
        -------
        str
            The (possibly cached) access token.
        """
        ts_diff = (self._token_expiration_ts - datetime.now(UTC)).total_seconds()
        _LOG.debug("Time to renew the token: %.2f sec.", ts_diff)
        if ts_diff < self._req_interval:
            _LOG.debug("Request new accessToken")
            res = self.get_credential().get_token("https://management.azure.com/.default")
            # `expires_on` is a POSIX timestamp; keep it timezone-aware for comparisons.
            self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC)
            self._access_token = res.token
            _LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts)
        return self._access_token

    def get_auth_headers(self) -> dict:
        """Get the authorization part of HTTP headers for REST API calls."""
        return {"Authorization": "Bearer " + self.get_access_token()}