Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py: 76%

178 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Base class for certain Azure Services classes that do deployments.""" 

6 

7import abc 

8import json 

9import logging 

10import time 

11from typing import Any, Callable, Dict, List, Optional, Tuple, Union 

12 

13import requests 

14from requests.adapters import HTTPAdapter, Retry 

15 

16from mlos_bench.dict_templater import DictTemplater 

17from mlos_bench.environments.status import Status 

18from mlos_bench.services.base_service import Service 

19from mlos_bench.services.types.authenticator_type import SupportsAuth 

20from mlos_bench.util import check_required_params, merge_parameters 

21 

# Logger for this module, named after the module itself via ``__name__``.
_LOG: logging.Logger = logging.getLogger(name=__name__)

23 

24 

class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
    """Helper methods to manage and deploy Azure resources via REST APIs."""

    _POLL_INTERVAL = 4  # seconds
    _POLL_TIMEOUT = 300  # seconds
    _REQUEST_TIMEOUT = 5  # seconds
    _REQUEST_TOTAL_RETRIES = 10  # Total number retries for each request
    # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))
    _REQUEST_RETRY_BACKOFF_FACTOR = 0.3

    # Azure Resources Deployment REST API as described in
    # https://docs.microsoft.com/en-us/rest/api/resources/deployments

    _URL_DEPLOY = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/Microsoft.Resources"
        "/deployments/{deployment_name}"
        "?api-version=2022-05-01"
    )

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new instance of an Azure Services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Must contain "subscription" and "resourceGroup".
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(config, global_config, parent, methods)

        check_required_params(
            self.config,
            [
                "subscription",
                "resourceGroup",
            ],
        )

        # These parameters can come from command line as strings, so conversion is needed.
        self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL))
        self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT))
        self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
        self._total_retries = int(
            self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)
        )
        self._backoff_factor = float(
            self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)
        )

        self._deploy_template = {}
        self._deploy_params = {}
        if self.config.get("deploymentTemplatePath") is not None:
            # TODO: Provide external schema validation?
            template = self.config_loader_service.load_config(
                self.config["deploymentTemplatePath"],
                schema_type=None,
            )
            assert template is not None and isinstance(template, dict)
            self._deploy_template = template

            # Allow for recursive variable expansion as we do with global params and const_args.
            deploy_params = DictTemplater(self.config["deploymentTemplateParameters"]).expand_vars(
                extra_source_dict=global_config
            )
            self._deploy_params = merge_parameters(dest=deploy_params, source=global_config)
        else:
            _LOG.info(
                "No deploymentTemplatePath provided. Deployment services will be unavailable.",
            )

    @property
    def deploy_params(self) -> dict:
        """Get the deployment parameters."""
        return self._deploy_params

    @abc.abstractmethod
    def _set_default_params(self, params: dict) -> dict:
        """
        Optionally set some default parameters for the request.

        Parameters
        ----------
        params : dict
            The parameters.

        Returns
        -------
        dict
            The updated parameters.
        """
        raise NotImplementedError("Should be overridden by subclass.")

    def _get_session(self, params: dict) -> requests.Session:
        """
        Get a session object that includes automatic retries and headers for REST API
        calls.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs. Optional "requestTotalRetries"
            and "requestBackoffFactor" entries override the service-level defaults
            for this session only.

        Returns
        -------
        requests.Session
            A new session. The caller owns it and should close it (e.g., by using
            it as a context manager) to release its pooled connections.
        """
        # Like in __init__, per-call overrides may arrive as strings (e.g., from the
        # command line); urllib3's Retry needs numeric values.
        total_retries = int(params.get("requestTotalRetries", self._total_retries))
        backoff_factor = float(params.get("requestBackoffFactor", self._backoff_factor))
        # Resolve auth headers first so a missing auth service does not leave an
        # orphaned session behind.
        headers = self._get_headers()
        session = requests.Session()
        session.mount(
            "https://",
            HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)),
        )
        session.headers.update(headers)
        return session

    def _get_headers(self) -> dict:
        """Get the headers for the REST API calls (provided by the parent auth service)."""
        assert self._parent is not None and isinstance(
            self._parent, SupportsAuth
        ), "Authorization service not provided. Include service-auth.jsonc?"
        return self._parent.get_auth_headers()

    @staticmethod
    def _format_response_body(response: requests.Response) -> str:
        """
        Render a response body for debug logging.

        JSON bodies are pretty-printed; anything that does not parse as JSON is
        returned as raw text, so that enabling DEBUG logging can never raise and
        alter the caller's control flow. An empty body yields "".
        """
        if not response.content:
            return ""
        try:
            return json.dumps(response.json(), indent=2)
        except ValueError:  # covers json / requests JSONDecodeError
            return response.text

    @staticmethod
    def _extract_arm_parameters(json_data: dict) -> dict:
        """
        Extract parameters from the ARM Template REST response JSON.

        Returns
        -------
        parameters : dict
            Flat dictionary of parameters and their values.
            Parameters whose "value" is missing or None are omitted.
        """
        return {
            key: val.get("value")
            for (key, val) in json_data.get("properties", {}).get("parameters", {}).items()
            if val.get("value") is not None
        }

    def _azure_rest_api_post_helper(self, params: dict, url: str) -> Tuple[Status, dict]:
        """
        General pattern for performing an action on an Azure resource via its REST API.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        url: str
            REST API url for the target to perform on the Azure VM.
            Should be a url that we intend to POST to.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
            Result will have a value for 'asyncResultsUrl' if status is PENDING,
            and 'pollInterval' if suggested by the API.
        """
        _LOG.debug("Request: POST %s", url)

        response = requests.post(url, headers=self._get_headers(), timeout=self._request_timeout)
        _LOG.debug("Response: %s", response)

        # Logical flow for async operations based on:
        # https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/async-operations
        if response.status_code == 200:
            return (Status.SUCCEEDED, params.copy())
        elif response.status_code == 202:
            result = params.copy()
            if "Azure-AsyncOperation" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Azure-AsyncOperation")
            elif "Location" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Location")
            if "Retry-After" in response.headers:
                retry_after = response.headers["Retry-After"]
                try:
                    result["pollInterval"] = float(retry_after)
                except ValueError:
                    # HTTP also allows an HTTP-date here; fall back to the default
                    # poll interval rather than failing the whole operation.
                    _LOG.warning("Ignoring non-numeric Retry-After header: %s", retry_after)

            return (Status.PENDING, result)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            # _LOG.error("Bad Request:\n%s", response.request.body)
            return (Status.FAILED, {})

    def _check_operation_status(self, params: dict) -> Tuple[Status, dict]:
        """
        Checks the status of a pending operation on an Azure resource.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncResultsUrl" key to get the results.
            If the key is not present, return Status.PENDING.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, RUNNING, SUCCEEDED, FAILED}
            RUNNING is also returned when the status request times out.
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        url = params.get("asyncResultsUrl")
        if url is None:
            return Status.PENDING, {}

        # Close the session (and its connection pool) after each poll.
        with self._get_session(params) as session:
            try:
                response = session.get(url, timeout=self._request_timeout)
            except requests.exceptions.ReadTimeout:
                _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
                return Status.RUNNING, {}
            except requests.exceptions.RequestException as ex:
                _LOG.exception("Error in request checking operation status", exc_info=ex)
                return (Status.FAILED, {})

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Response: %s\n%s", response, self._format_response_body(response))

        if response.status_code == 200:
            output = response.json()
            status = output.get("status")
            if status == "InProgress":
                return Status.RUNNING, {}
            elif status == "Succeeded":
                return Status.SUCCEEDED, output

        _LOG.error("Response: %s :: %s", response, response.text)
        return Status.FAILED, {}

    def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]:
        """
        Waits for a pending operation on an Azure resource to resolve to SUCCEEDED or
        FAILED. Return TIMED_OUT when timing out.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        is_setup : bool
            If True, wait for resource being deployed; otherwise, wait for
            successful deprovisioning. (Currently only affects the log message.)

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        params = self._set_default_params(params)
        _LOG.info(
            "Wait for %s to %s",
            params.get("deploymentName"),
            "provision" if is_setup else "deprovision",
        )
        return self._wait_while(self._check_deployment, Status.PENDING, params)

    def _wait_while(
        self,
        func: Callable[[dict], Tuple[Status, dict]],
        loop_status: Status,
        params: dict,
    ) -> Tuple[Status, dict]:
        """
        Invoke `func` periodically while the status is equal to `loop_status`. Return
        TIMED_OUT when timing out.

        Parameters
        ----------
        func : a function
            A function that takes `params` and returns a pair of (Status, {})
        loop_status: Status
            Steady state status - keep polling `func` while it returns `loop_status`.
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Requires deploymentName.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=["deploymentName"],
        )

        poll_period = params.get("pollInterval", self._poll_interval)

        _LOG.debug(
            "Wait for %s status %s :: poll %.2f timeout %d s",
            config["deploymentName"],
            loop_status,
            poll_period,
            self._poll_timeout,
        )

        ts_timeout = time.time() + self._poll_timeout
        poll_delay = poll_period
        while True:
            # Wait for the suggested time first then check status
            ts_start = time.time()
            if ts_start >= ts_timeout:
                break

            if poll_delay > 0:
                _LOG.debug("Sleep for: %.2f of %.2f s", poll_delay, poll_period)
                time.sleep(poll_delay)

            (status, output) = func(params)
            if status != loop_status:
                return status, output

            # Subtract the time spent in `func` so polls stay roughly `poll_period` apart.
            ts_end = time.time()
            poll_delay = poll_period - ts_end + ts_start

        _LOG.warning("Request timed out: %s", params)
        return (Status.TIMED_OUT, {})

    def _check_deployment(self, params: dict) -> Tuple[Status, dict]:
        # pylint: disable=too-many-return-statements
        """
        Check the provisioning state of an Azure deployment.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Merged over the service config; the result must provide
            "subscription", "resourceGroup", and "deploymentName", which are
            used to build the deployment URL.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, PENDING, FAILED}:
            SUCCEEDED if the provisioningState is "Succeeded";
            PENDING if the state is transitional, the deployment is not found
            (HTTP 404), or the status request timed out;
            FAILED otherwise.
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "subscription",
                "resourceGroup",
                "deploymentName",
            ],
        )

        _LOG.info("Check deployment: %s", config["deploymentName"])

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        # Close the session (and its connection pool) after each poll.
        with self._get_session(params) as session:
            try:
                response = session.get(url, timeout=self._request_timeout)
            except requests.exceptions.ReadTimeout:
                _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
                # PENDING (not RUNNING): `_wait_deployment` polls this method while it
                # returns PENDING, so a transient timeout must keep the wait going.
                return Status.PENDING, {}
            except requests.exceptions.RequestException as ex:
                _LOG.exception("Error in request checking deployment", exc_info=ex)
                return (Status.FAILED, {})

        _LOG.debug("Response: %s", response)

        if response.status_code == 200:
            output = response.json()
            state = output.get("properties", {}).get("provisioningState", "")

            if state == "Succeeded":
                return (Status.SUCCEEDED, {})
            elif state in {"Accepted", "Creating", "Deleting", "Running", "Updating"}:
                return (Status.PENDING, {})
            else:
                _LOG.error("Response: %s :: %s", response, json.dumps(output, indent=2))
                return (Status.FAILED, {})
        elif response.status_code == 404:
            return (Status.PENDING, {})

        _LOG.error("Response: %s :: %s", response, response.text)
        return (Status.FAILED, {})

    def _provision_resource(self, params: dict) -> Tuple[Status, dict]:
        """
        Attempts to (re)deploy a resource.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Tunables are variable parameters that, together with the
            Environment configuration, are sufficient to provision the resource.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. On HTTP 201 the result is the deployment
            parameters merged with the input `params` plus the parameters extracted
            from the response JSON; on HTTP 200 it is the service config merged with
            `params`; it is {} if the status is FAILED.
            Status is one of {PENDING, FAILED}

        Raises
        ------
        ValueError
            If no deployment template was loaded for this service.
        """
        if not self._deploy_template:
            raise ValueError(f"Missing deployment template: {self}")
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=["deploymentName"],
        )
        _LOG.info("Deploy: %s :: %s", config["deploymentName"], params)

        params = merge_parameters(dest=self._deploy_params.copy(), source=params)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Deploy: %s merged params ::\n%s",
                config["deploymentName"],
                json.dumps(params, indent=2),
            )

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        json_req = {
            "properties": {
                "mode": "Incremental",
                "template": self._deploy_template,
                # Only pass the parameters that the ARM template actually declares.
                "parameters": {
                    key: {"value": val}
                    for (key, val) in params.items()
                    if key in self._deploy_template.get("parameters", {})
                },
            }
        }

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))

        response = requests.put(
            url,
            json=json_req,
            headers=self._get_headers(),
            timeout=self._request_timeout,
        )

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Response: %s\n%s", response, self._format_response_body(response))
        else:
            _LOG.info("Response: %s", response)

        if response.status_code == 200:
            return (Status.PENDING, config)
        elif response.status_code == 201:
            output = self._extract_arm_parameters(response.json())
            if _LOG.isEnabledFor(logging.DEBUG):
                _LOG.debug("Extracted parameters:\n%s", json.dumps(output, indent=2))
            params.update(output)
            params.setdefault("asyncResultsUrl", url)
            params.setdefault("deploymentName", config["deploymentName"])
            return (Status.PENDING, params)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            # _LOG.error("Bad Request:\n%s", response.request.body)
            return (Status.FAILED, {})