Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py: 76%

179 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Base class for certain Azure Services classes that do deployments.""" 

6 

7import abc 

8import json 

9import logging 

10import time 

11from collections.abc import Callable 

12from typing import Any 

13 

14import requests 

15from requests.adapters import HTTPAdapter, Retry 

16 

17from mlos_bench.dict_templater import DictTemplater 

18from mlos_bench.environments.status import Status 

19from mlos_bench.services.base_service import Service 

20from mlos_bench.services.types.authenticator_type import SupportsAuth 

21from mlos_bench.util import check_required_params, merge_parameters 

22 

23_LOG = logging.getLogger(__name__) 

24 

25 

26class AzureDeploymentService(Service, metaclass=abc.ABCMeta): 

27 """Helper methods to manage and deploy Azure resources via REST APIs.""" 

28 

# Default delay between two status polls, in seconds.
_POLL_INTERVAL = 4
# Default overall time budget for polling, in seconds.
_POLL_TIMEOUT = 300
# Default timeout of a single HTTP request, in seconds.
_REQUEST_TIMEOUT = 5
# Default total number of retries for each request.
_REQUEST_TOTAL_RETRIES = 10
# Retries are delayed by: {backoff factor} * (2 ** ({number of previous retries})) seconds.
_REQUEST_RETRY_BACKOFF_FACTOR = 0.3

# URL template for the Azure Resources Deployment REST API, as described in
# https://docs.microsoft.com/en-us/rest/api/resources/deployments
# Placeholders: {subscription}, {resource_group}, {deployment_name}.
_URL_DEPLOY = "".join(
    (
        "https://management.azure.com",
        "/subscriptions/{subscription}",
        "/resourceGroups/{resource_group}",
        "/providers/Microsoft.Resources",
        "/deployments/{deployment_name}",
        "?api-version=2022-05-01",
    )
)

47 

def __init__(
    self,
    config: dict[str, Any] | None = None,
    global_config: dict[str, Any] | None = None,
    parent: Service | None = None,
    methods: dict[str, Callable] | list[Callable] | None = None,
):
    """
    Create a new instance of an Azure Services proxy.

    Parameters
    ----------
    config : dict
        Free-format dictionary that contains the benchmark environment
        configuration. Must contain "subscription" and "resourceGroup".
        Optional keys read here: "pollInterval", "pollTimeout",
        "requestTimeout", "requestTotalRetries", "requestBackoffFactor",
        "deploymentTemplatePath", and "deploymentTemplateParameters".
    global_config : dict
        Free-format dictionary of global parameters.
    parent : Service
        Parent service that can provide mixin functions.
    methods : Union[dict[str, Callable], list[Callable], None]
        New methods to register with the service.

    Raises
    ------
    ValueError
        If the deployment template loaded from "deploymentTemplatePath"
        is not a dictionary.
    """
    super().__init__(config, global_config, parent, methods)

    check_required_params(
        self.config,
        [
            "subscription",
            "resourceGroup",
        ],
    )

    # These parameters can come from command line as strings, so conversion is needed.
    self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL))
    self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT))
    self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
    self._total_retries = int(
        self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)
    )
    self._backoff_factor = float(
        self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)
    )

    self._deploy_template = {}
    self._deploy_params = {}

    template_path = self.config.get("deploymentTemplatePath")
    if template_path is None:
        _LOG.info(
            "No deploymentTemplatePath provided. Deployment services will be unavailable.",
        )
        return

    # TODO: Provide external schema validation?
    template = self.config_loader_service.load_config(
        template_path,
        schema_type=None,
    )
    # Explicit check instead of `assert`, which is stripped when running with -O.
    if not isinstance(template, dict):
        raise ValueError(f"Deployment template must be a dict: {template_path}")
    self._deploy_template = template

    # A template may legitimately take no parameters: treat a missing (or null)
    # "deploymentTemplateParameters" entry as empty instead of raising KeyError.
    template_params = self.config.get("deploymentTemplateParameters") or {}

    # Allow for recursive variable expansion as we do with global params and const_args.
    deploy_params = DictTemplater(template_params).expand_vars(
        extra_source_dict=global_config
    )
    self._deploy_params = merge_parameters(dest=deploy_params, source=global_config)

111 

@property
def deploy_params(self) -> dict:
    """
    Deployment parameters held by this service.

    Returns
    -------
    dict
        The service's own parameter dictionary (not a copy).
    """
    current_params = self._deploy_params
    return current_params

116 

@abc.abstractmethod
def _set_default_params(self, params: dict) -> dict:
    """
    Hook for subclasses to fill in default values for request parameters.

    This base implementation does nothing useful: it always raises, so every
    concrete subclass must provide its own version.

    Parameters
    ----------
    params : dict
        The request parameters.

    Returns
    -------
    dict
        The parameters with any defaults applied.

    Raises
    ------
    NotImplementedError
        Always, when the base implementation is called.
    """
    raise NotImplementedError("Should be overridden by subclass.")

133 

def _get_session(self, params: dict) -> requests.Session:
    """
    Get a session object that includes automatic retries and headers for REST API
    calls.

    Parameters
    ----------
    params : dict
        Flat dictionary of parameters. "requestTotalRetries" and
        "requestBackoffFactor", when present, override the values configured
        on the service.

    Returns
    -------
    session : requests.Session
        A new session with a retrying adapter mounted for "https://" and the
        authorization headers set. The caller is responsible for closing it.
    """
    # Coerce the per-call overrides the same way __init__ coerces the config values,
    # since they may be supplied as strings.
    total_retries = int(params.get("requestTotalRetries", self._total_retries))
    backoff_factor = float(params.get("requestBackoffFactor", self._backoff_factor))
    session = requests.Session()
    session.mount(
        "https://",
        HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)),
    )
    session.headers.update(self._get_headers())
    return session

147 

def _get_headers(self) -> dict:
    """
    Build the headers for the REST API calls.

    The headers are obtained from the parent service, which must implement
    the `SupportsAuth` interface.

    Returns
    -------
    dict
        Whatever the parent's `get_auth_headers()` returns.
    """
    auth_service = self._parent
    assert auth_service is not None and isinstance(
        auth_service, SupportsAuth
    ), "Authorization service not provided. Include service-auth.jsonc?"
    return auth_service.get_auth_headers()

154 

@staticmethod
def _extract_arm_parameters(json_data: dict) -> dict:
    """
    Extract parameters from the ARM Template REST response JSON.

    Reads `json_data["properties"]["parameters"]` and keeps the "value" of
    every entry that has a non-None one. Missing or null "properties" /
    "parameters" sections, and entries that are not dictionaries, are
    tolerated and simply contribute nothing.

    Parameters
    ----------
    json_data : dict
        Parsed JSON body of the REST response.

    Returns
    -------
    parameters : dict
        Flat dictionary of parameters and their values.
    """
    # `or {}` also covers keys that are present but set to JSON null.
    properties = json_data.get("properties") or {}
    arm_params = properties.get("parameters") or {}

    extracted = {}
    for key, val in arm_params.items():
        if not isinstance(val, dict):
            continue
        value = val.get("value")
        # Falsy-but-valid values (0, "", False) are kept; only None is dropped.
        if value is not None:
            extracted[key] = value
    return extracted

170 

def _azure_rest_api_post_helper(self, params: dict, url: str) -> tuple[Status, dict]:
    """
    General pattern for performing an action on an Azure resource via its REST API.

    Parameters
    ----------
    params: dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        It is copied, never modified in place.
    url: str
        REST API url for the target to perform on the Azure VM.
        Should be a url that we intend to POST to.

    Returns
    -------
    result : (Status, dict={})
        A pair of Status and result.
        Status is one of {PENDING, SUCCEEDED, FAILED}
        Result will have a value for 'asyncResultsUrl' if status is PENDING,
        and 'pollInterval' if suggested by the API.
    """
    _LOG.debug("Request: POST %s", url)

    response = requests.post(url, headers=self._get_headers(), timeout=self._request_timeout)
    _LOG.debug("Response: %s", response)

    # Logical flow for async operations based on:
    # https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/async-operations
    if response.status_code == 200:
        return (Status.SUCCEEDED, params.copy())

    if response.status_code == 202:
        result = params.copy()
        # Prefer the Azure-AsyncOperation URL; fall back to Location.
        if "Azure-AsyncOperation" in response.headers:
            result["asyncResultsUrl"] = response.headers.get("Azure-AsyncOperation")
        elif "Location" in response.headers:
            result["asyncResultsUrl"] = response.headers.get("Location")
        if "Retry-After" in response.headers:
            retry_after = response.headers["Retry-After"]
            try:
                result["pollInterval"] = float(retry_after)
            except ValueError:
                # Retry-After may also be an HTTP-date; in that case leave
                # 'pollInterval' unset instead of crashing on the conversion.
                _LOG.warning("Ignoring non-numeric Retry-After header: %s", retry_after)

        return (Status.PENDING, result)

    _LOG.error("Response: %s :: %s", response, response.text)
    return (Status.FAILED, {})

214 

def _check_operation_status(self, params: dict) -> tuple[Status, dict]:
    """
    Checks the status of a pending operation on an Azure resource.

    Parameters
    ----------
    params: dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Must have the "asyncResultsUrl" key to get the results.
        If the key is not present, return Status.PENDING.

    Returns
    -------
    result : (Status, dict)
        A pair of Status and result.
        Status is one of {PENDING, RUNNING, SUCCEEDED, FAILED}
        Result is info on the operation runtime if SUCCEEDED, otherwise {}.
    """
    url = params.get("asyncResultsUrl")
    if url is None:
        return Status.PENDING, {}

    session = self._get_session(params)
    try:
        response = session.get(url, timeout=self._request_timeout)
    except requests.exceptions.ReadTimeout:
        # A slow response is not a failure: report RUNNING so the caller can retry.
        _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
        return Status.RUNNING, {}
    except requests.exceptions.RequestException as ex:
        _LOG.exception("Error in request checking operation status", exc_info=ex)
        return (Status.FAILED, {})
    finally:
        # A new session is created on every call, so release its connections here.
        # The request is not streamed, so the response stays usable after the close.
        session.close()

    if _LOG.isEnabledFor(logging.DEBUG):
        try:
            body = json.dumps(response.json(), indent=2) if response.content else ""
        except ValueError:
            # Non-JSON body (e.g., an error page): log it raw rather than let a
            # debug statement raise.
            body = response.text
        _LOG.debug("Response: %s\n%s", response, body)

    # NOTE(review): any non-200 reply, including a 202 from a Location-style
    # polling URL, is reported as FAILED here - confirm this is intended.
    if response.status_code == 200:
        output = response.json()
        status = output.get("status")
        if status == "InProgress":
            return Status.RUNNING, {}
        elif status == "Succeeded":
            return Status.SUCCEEDED, output

    _LOG.error("Response: %s :: %s", response, response.text)
    return Status.FAILED, {}

264 

def _wait_deployment(self, params: dict, *, is_setup: bool) -> tuple[Status, dict]:
    """
    Block until the deployment leaves the PENDING state.

    Polls `_check_deployment` through `_wait_while` for as long as it keeps
    reporting PENDING. Return TIMED_OUT when timing out.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
    is_setup : bool
        If True, wait for resource being deployed; otherwise, wait for
        successful deprovisioning. Here it only selects the wording of the
        log message.

    Returns
    -------
    result : (Status, dict)
        A pair of Status and result, exactly as returned by `_wait_while`.
    """
    params = self._set_default_params(params)
    action = "provision" if is_setup else "deprovision"
    _LOG.info("Wait for %s to %s", params.get("deploymentName"), action)
    return self._wait_while(self._check_deployment, Status.PENDING, params)

292 

def _wait_while(
    self,
    func: Callable[[dict], tuple[Status, dict]],
    loop_status: Status,
    params: dict,
) -> tuple[Status, dict]:
    """
    Invoke `func` periodically while the status is equal to `loop_status`. Return
    TIMED_OUT when timing out.

    Parameters
    ----------
    func : a function
        A function that takes `params` and returns a pair of (Status, {})
    loop_status: Status
        Steady state status - keep polling `func` while it returns `loop_status`.
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Requires deploymentName. An optional "pollInterval" (seconds)
        overrides the poll interval configured on the service.

    Returns
    -------
    result : (Status, dict)
        A pair of Status and result: the first pair from `func` whose status
        differs from `loop_status`, or (TIMED_OUT, {}) once the poll timeout
        has elapsed.
    """
    params = self._set_default_params(params)
    config = merge_parameters(
        dest=self.config.copy(),
        source=params,
        required_keys=["deploymentName"],
    )

    # Coerce to float in case the interval was supplied as a string.
    poll_period = float(params.get("pollInterval", self._poll_interval))

    _LOG.debug(
        "Wait for %s status %s :: poll %.2f timeout %d s",
        config["deploymentName"],
        loop_status,
        poll_period,
        self._poll_timeout,
    )

    ts_now = time.time()
    ts_timeout = ts_now + self._poll_timeout
    poll_delay = poll_period
    while ts_now < ts_timeout:
        # Wait for the suggested time first then check status
        if poll_delay > 0:
            _LOG.debug("Sleep for: %.2f of %.2f s", poll_delay, poll_period)
            time.sleep(poll_delay)

        # Time only the status check itself. Including the sleep in the measurement
        # made the next delay negative, so every other poll fired back-to-back.
        ts_call = time.time()
        (status, output) = func(params)
        if status != loop_status:
            return status, output

        ts_now = time.time()
        # Keep polls roughly `poll_period` apart by discounting the time the check took.
        poll_delay = poll_period - (ts_now - ts_call)

    _LOG.warning("Request timed out: %s", params)
    return (Status.TIMED_OUT, {})

356 

def _check_deployment(self, params: dict) -> tuple[Status, dict]:
    # pylint: disable=too-many-return-statements
    """
    Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Merged with the service config to obtain "subscription",
        "resourceGroup", and "deploymentName"; also passed on to
        `_get_session()`.

    Returns
    -------
    result : (Status, dict={})
        A pair of Status and result. The result is always {}.
        Status is one of {SUCCEEDED, PENDING, FAILED}
    """
    params = self._set_default_params(params)
    config = merge_parameters(
        dest=self.config.copy(),
        source=params,
        required_keys=[
            "subscription",
            "resourceGroup",
            "deploymentName",
        ],
    )

    _LOG.info("Check deployment: %s", config["deploymentName"])

    url = self._URL_DEPLOY.format(
        subscription=config["subscription"],
        resource_group=config["resourceGroup"],
        deployment_name=config["deploymentName"],
    )

    session = self._get_session(params)
    try:
        response = session.get(url, timeout=self._request_timeout)
    except requests.exceptions.ReadTimeout:
        _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
        # This method is polled via `_wait_while(..., Status.PENDING, ...)`, so any
        # other status ends the wait. Report PENDING so that a single slow response
        # is retried on the next poll instead of aborting the wait.
        return Status.PENDING, {}
    except requests.exceptions.RequestException as ex:
        _LOG.exception("Error in request checking deployment", exc_info=ex)
        return (Status.FAILED, {})
    finally:
        # A new session is created on every call, so release its connections here.
        # The request is not streamed, so the response stays usable after the close.
        session.close()

    _LOG.debug("Response: %s", response)

    if response.status_code == 200:
        output = response.json()
        state = output.get("properties", {}).get("provisioningState", "")

        if state == "Succeeded":
            return (Status.SUCCEEDED, {})
        elif state in {"Accepted", "Creating", "Deleting", "Running", "Updating"}:
            # Still in flight: keep polling.
            return (Status.PENDING, {})
        else:
            _LOG.error("Response: %s :: %s", response, json.dumps(output, indent=2))
            return (Status.FAILED, {})
    elif response.status_code == 404:
        # The deployment is not visible (yet): keep polling.
        return (Status.PENDING, {})

    _LOG.error("Response: %s :: %s", response, response.text)
    return (Status.FAILED, {})

422 

def _provision_resource(self, params: dict) -> tuple[Status, dict]:
    """
    Attempts to (re)deploy a resource.

    Parameters
    ----------
    params : dict
        Flat dictionary of (key, value) pairs of tunable parameters.
        Tunables are variable parameters that, together with the
        Environment configuration, are sufficient to provision the resource.

    Returns
    -------
    result : (Status, dict={})
        A pair of Status and result.
        Status is one of {PENDING, FAILED}.
        On HTTP 200 the result is the merged service config; on HTTP 201 it is
        the merged deployment parameters plus the parameters extracted from the
        response JSON, "asyncResultsUrl", and "deploymentName"; otherwise {}.

    Raises
    ------
    ValueError
        If no deployment template was loaded for this service.
    """
    if not self._deploy_template:
        raise ValueError(f"Missing deployment template: {self}")
    params = self._set_default_params(params)
    config = merge_parameters(
        dest=self.config.copy(),
        source=params,
        required_keys=["deploymentName"],
    )
    _LOG.info("Deploy: %s :: %s", config["deploymentName"], params)

    params = merge_parameters(dest=self._deploy_params.copy(), source=params)
    if _LOG.isEnabledFor(logging.DEBUG):
        _LOG.debug(
            "Deploy: %s merged params ::\n%s",
            config["deploymentName"],
            json.dumps(params, indent=2),
        )

    url = self._URL_DEPLOY.format(
        subscription=config["subscription"],
        resource_group=config["resourceGroup"],
        deployment_name=config["deploymentName"],
    )

    # Only pass the parameters that the template actually declares.
    template_params = self._deploy_template.get("parameters", {})
    json_req = {
        "properties": {
            "mode": "Incremental",
            "template": self._deploy_template,
            "parameters": {
                key: {"value": val} for (key, val) in params.items() if key in template_params
            },
        }
    }

    if _LOG.isEnabledFor(logging.DEBUG):
        _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))

    response = requests.put(
        url,
        json=json_req,
        headers=self._get_headers(),
        timeout=self._request_timeout,
    )

    if _LOG.isEnabledFor(logging.DEBUG):
        try:
            body = json.dumps(response.json(), indent=2) if response.content else ""
        except ValueError:
            # Non-JSON body (e.g., an error page): log it raw so that enabling debug
            # logging cannot raise before the status code is handled below.
            body = response.text
        _LOG.debug("Response: %s\n%s", response, body)
    else:
        _LOG.info("Response: %s", response)

    if response.status_code == 200:
        return (Status.PENDING, config)
    elif response.status_code == 201:
        output = self._extract_arm_parameters(response.json())
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Extracted parameters:\n%s", json.dumps(output, indent=2))
        params.update(output)
        # Values already present in params take precedence over these defaults.
        params.setdefault("asyncResultsUrl", url)
        params.setdefault("deploymentName", config["deploymentName"])
        return (Status.PENDING, params)
    else:
        _LOG.error("Response: %s :: %s", response, response.text)
        return (Status.FAILED, {})