Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_saas.py: 46%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""A collection Service functions for configuring SaaS instances on Azure.""" 

6import logging 

7from collections.abc import Callable 

8from typing import Any 

9 

10import requests 

11 

12from mlos_bench.environments.status import Status 

13from mlos_bench.services.base_service import Service 

14from mlos_bench.services.types.authenticator_type import SupportsAuth 

15from mlos_bench.services.types.remote_config_type import SupportsRemoteConfig 

16from mlos_bench.util import check_required_params, merge_parameters 

17 

18_LOG = logging.getLogger(__name__) 

19 

20 

21class AzureSaaSConfigService(Service, SupportsRemoteConfig): 

22 """Helper methods to configure Azure Flex services.""" 

23 

24 _REQUEST_TIMEOUT = 5 # seconds 

25 

26 # REST API for Azure SaaS DB Services configuration as described in: 

27 # https://learn.microsoft.com/en-us/rest/api/mysql/flexibleserver/configurations 

28 # https://learn.microsoft.com/en-us/rest/api/postgresql/flexibleserver/configurations 

29 # https://learn.microsoft.com/en-us/rest/api/mariadb/configurations 

30 

31 _URL_CONFIGURE = ( 

32 "https://management.azure.com" 

33 "/subscriptions/{subscription}" 

34 "/resourceGroups/{resource_group}" 

35 "/providers/{provider}" 

36 "/{server_type}/{vm_name}" 

37 "/{update}" 

38 "?api-version={api_version}" 

39 ) 

40 

41 def __init__( 

42 self, 

43 config: dict[str, Any] | None = None, 

44 global_config: dict[str, Any] | None = None, 

45 parent: Service | None = None, 

46 methods: dict[str, Callable] | list[Callable] | None = None, 

47 ): 

48 """ 

49 Create a new instance of Azure services proxy. 

50 

51 Parameters 

52 ---------- 

53 config : dict 

54 Free-format dictionary that contains the benchmark environment 

55 configuration. 

56 global_config : dict 

57 Free-format dictionary of global parameters. 

58 parent : Service 

59 Parent service that can provide mixin functions. 

60 methods : Union[dict[str, Callable], list[Callable], None] 

61 New methods to register with the service. 

62 """ 

63 super().__init__( 

64 config, 

65 global_config, 

66 parent, 

67 self.merge_methods(methods, [self.configure, self.is_config_pending]), 

68 ) 

69 

70 check_required_params( 

71 self.config, 

72 { 

73 "subscription", 

74 "resourceGroup", 

75 "provider", 

76 }, 

77 ) 

78 

79 # Provide sane defaults for known DB providers. 

80 provider = self.config.get("provider") 

81 if provider == "Microsoft.DBforMySQL": 

82 self._is_batch = self.config.get("supportsBatchUpdate", True) 

83 is_flex = self.config.get("isFlex", True) 

84 api_version = self.config.get("apiVersion", "2022-01-01") 

85 elif provider == "Microsoft.DBforMariaDB": 

86 self._is_batch = self.config.get("supportsBatchUpdate", False) 

87 is_flex = self.config.get("isFlex", False) 

88 api_version = self.config.get("apiVersion", "2018-06-01") 

89 elif provider == "Microsoft.DBforPostgreSQL": 

90 self._is_batch = self.config.get("supportsBatchUpdate", False) 

91 is_flex = self.config.get("isFlex", True) 

92 api_version = self.config.get("apiVersion", "2022-12-01") 

93 else: 

94 self._is_batch = self.config["supportsBatchUpdate"] 

95 is_flex = self.config["isFlex"] 

96 api_version = self.config["apiVersion"] 

97 

98 self._url_config_set = self._URL_CONFIGURE.format( 

99 subscription=self.config["subscription"], 

100 resource_group=self.config["resourceGroup"], 

101 provider=self.config["provider"], 

102 vm_name="{vm_name}", 

103 server_type="flexibleServers" if is_flex else "servers", 

104 update="updateConfigurations" if self._is_batch else "configurations/{param_name}", 

105 api_version=api_version, 

106 ) 

107 

108 self._url_config_get = self._URL_CONFIGURE.format( 

109 subscription=self.config["subscription"], 

110 resource_group=self.config["resourceGroup"], 

111 provider=self.config["provider"], 

112 vm_name="{vm_name}", 

113 server_type="flexibleServers" if is_flex else "servers", 

114 update="configurations", 

115 api_version=api_version, 

116 ) 

117 

118 # These parameters can come from command line as strings, so conversion is needed. 

119 self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT)) 

120 

121 def configure(self, config: dict[str, Any], params: dict[str, Any]) -> tuple[Status, dict]: 

122 """ 

123 Update the parameters of an Azure DB service. 

124 

125 Parameters 

126 ---------- 

127 config : dict[str, Any] 

128 Key/value pairs of configuration parameters (e.g., vmName). 

129 params : dict[str, Any] 

130 Key/value pairs of the service parameters to update. 

131 

132 Returns 

133 ------- 

134 result : (Status, dict) 

135 A pair of Status and result. The result is always {}. 

136 Status is one of {PENDING, SUCCEEDED, FAILED} 

137 """ 

138 if self._is_batch: 

139 return self._config_batch(config, params) 

140 return self._config_many(config, params) 

141 

142 def is_config_pending(self, config: dict[str, Any]) -> tuple[Status, dict]: 

143 """ 

144 Check if the configuration of an Azure DB service requires a reboot or restart. 

145 

146 Parameters 

147 ---------- 

148 config : dict[str, Any] 

149 Key/value pairs of configuration parameters (e.g., vmName). 

150 

151 Returns 

152 ------- 

153 result : (Status, dict) 

154 A pair of Status and result. A Boolean field 

155 "isConfigPendingRestart" indicates whether the service restart is required. 

156 If "isConfigPendingReboot" is set to True, rebooting a VM is necessary. 

157 Status is one of {PENDING, TIMED_OUT, SUCCEEDED, FAILED} 

158 """ 

159 config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"]) 

160 url = self._url_config_get.format(vm_name=config["vmName"]) 

161 _LOG.debug("Request: GET %s", url) 

162 response = requests.put(url, headers=self._get_headers(), timeout=self._request_timeout) 

163 _LOG.debug("Response: %s :: %s", response, response.text) 

164 if response.status_code == 504: 

165 return (Status.TIMED_OUT, {}) 

166 if response.status_code != 200: 

167 return (Status.FAILED, {}) 

168 # Currently, Azure Flex servers require a VM reboot. 

169 return ( 

170 Status.SUCCEEDED, 

171 { 

172 "isConfigPendingReboot": any( 

173 {"False": False, "True": True}[val["properties"]["isConfigPendingRestart"]] 

174 for val in response.json()["value"] 

175 ) 

176 }, 

177 ) 

178 

179 def _get_headers(self) -> dict: 

180 """Get the headers for the REST API calls.""" 

181 assert self._parent is not None and isinstance( 

182 self._parent, SupportsAuth 

183 ), "Authorization service not provided. Include service-auth.jsonc?" 

184 return self._parent.get_auth_headers() 

185 

186 def _config_one( 

187 self, 

188 config: dict[str, Any], 

189 param_name: str, 

190 param_value: Any, 

191 ) -> tuple[Status, dict]: 

192 """ 

193 Update a single parameter of the Azure DB service. 

194 

195 Parameters 

196 ---------- 

197 config : dict[str, Any] 

198 Key/value pairs of configuration parameters (e.g., vmName). 

199 param_name : str 

200 Name of the parameter to update. 

201 param_value : Any 

202 Value of the parameter to update. 

203 

204 Returns 

205 ------- 

206 result : (Status, dict) 

207 A pair of Status and result. The result is always {}. 

208 Status is one of {PENDING, SUCCEEDED, FAILED} 

209 """ 

210 config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"]) 

211 url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name) 

212 _LOG.debug("Request: PUT %s", url) 

213 response = requests.put( 

214 url, 

215 headers=self._get_headers(), 

216 json={"properties": {"value": str(param_value)}}, 

217 timeout=self._request_timeout, 

218 ) 

219 _LOG.debug("Response: %s :: %s", response, response.text) 

220 if response.status_code == 504: 

221 return (Status.TIMED_OUT, {}) 

222 if response.status_code == 200: 

223 return (Status.SUCCEEDED, {}) 

224 return (Status.FAILED, {}) 

225 

226 def _config_many(self, config: dict[str, Any], params: dict[str, Any]) -> tuple[Status, dict]: 

227 """ 

228 Update the parameters of an Azure DB service one-by-one. (If batch API is not 

229 available for it). 

230 

231 Parameters 

232 ---------- 

233 config : dict[str, Any] 

234 Key/value pairs of configuration parameters (e.g., vmName). 

235 params : dict[str, Any] 

236 Key/value pairs of the service parameters to update. 

237 

238 Returns 

239 ------- 

240 result : (Status, dict) 

241 A pair of Status and result. The result is always {}. 

242 Status is one of {PENDING, SUCCEEDED, FAILED} 

243 """ 

244 for param_name, param_value in params.items(): 

245 (status, result) = self._config_one(config, param_name, param_value) 

246 if not status.is_succeeded(): 

247 return (status, result) 

248 return (Status.SUCCEEDED, {}) 

249 

250 def _config_batch(self, config: dict[str, Any], params: dict[str, Any]) -> tuple[Status, dict]: 

251 """ 

252 Batch update the parameters of an Azure DB service. 

253 

254 Parameters 

255 ---------- 

256 config : dict[str, Any] 

257 Key/value pairs of configuration parameters (e.g., vmName). 

258 params : dict[str, Any] 

259 Key/value pairs of the service parameters to update. 

260 

261 Returns 

262 ------- 

263 result : (Status, dict) 

264 A pair of Status and result. The result is always {}. 

265 Status is one of {PENDING, SUCCEEDED, FAILED} 

266 """ 

267 config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"]) 

268 url = self._url_config_set.format(vm_name=config["vmName"]) 

269 json_req = { 

270 "value": [ 

271 {"name": key, "properties": {"value": str(val)}} for (key, val) in params.items() 

272 ], 

273 # "resetAllToDefault": "True" 

274 } 

275 _LOG.debug("Request: POST %s", url) 

276 response = requests.post( 

277 url, 

278 headers=self._get_headers(), 

279 json=json_req, 

280 timeout=self._request_timeout, 

281 ) 

282 _LOG.debug("Response: %s :: %s", response, response.text) 

283 if response.status_code == 504: 

284 return (Status.TIMED_OUT, {}) 

285 if response.status_code == 200: 

286 return (Status.SUCCEEDED, {}) 

287 return (Status.FAILED, {})