Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py: 76%
179 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Base class for certain Azure Services classes that do deployments."""
7import abc
8import json
9import logging
10import time
11from collections.abc import Callable
12from typing import Any
14import requests
15from requests.adapters import HTTPAdapter, Retry
17from mlos_bench.dict_templater import DictTemplater
18from mlos_bench.environments.status import Status
19from mlos_bench.services.base_service import Service
20from mlos_bench.services.types.authenticator_type import SupportsAuth
21from mlos_bench.util import check_required_params, merge_parameters
23_LOG = logging.getLogger(__name__)
class AzureDeploymentService(Service, metaclass=abc.ABCMeta):
    """
    Helper methods to manage and deploy Azure resources via REST APIs.

    Subclasses must implement `_set_default_params()`. HTTP requests are
    authorized with the headers supplied by the parent service, which must
    implement `SupportsAuth`.
    """

    _POLL_INTERVAL = 4  # seconds
    _POLL_TIMEOUT = 300  # seconds
    _REQUEST_TIMEOUT = 5  # seconds
    _REQUEST_TOTAL_RETRIES = 10  # Total number retries for each request
    # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries}))
    _REQUEST_RETRY_BACKOFF_FACTOR = 0.3

    # Azure Resources Deployment REST API as described in
    # https://docs.microsoft.com/en-us/rest/api/resources/deployments

    _URL_DEPLOY = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/Microsoft.Resources"
        "/deployments/{deployment_name}"
        "?api-version=2022-05-01"
    )

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new instance of an Azure Services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Must contain "subscription" and "resourceGroup".
            Optional keys read here: "pollInterval", "pollTimeout",
            "requestTimeout", "requestTotalRetries", "requestBackoffFactor",
            "deploymentTemplatePath", and "deploymentTemplateParameters".
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[dict[str, Callable], list[Callable], None]
            New methods to register with the service.
        """
        super().__init__(config, global_config, parent, methods)

        check_required_params(
            self.config,
            [
                "subscription",
                "resourceGroup",
            ],
        )

        # These parameters can come from command line as strings, so conversion is needed.
        self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL))
        self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT))
        self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))
        self._total_retries = int(
            self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)
        )
        self._backoff_factor = float(
            self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)
        )

        # Both stay empty when no template path is configured; `_provision_resource()`
        # refuses to run with an empty template.
        self._deploy_template: dict = {}
        self._deploy_params: dict = {}
        if self.config.get("deploymentTemplatePath") is not None:
            # TODO: Provide external schema validation?
            template = self.config_loader_service.load_config(
                self.config["deploymentTemplatePath"],
                schema_type=None,
            )
            assert template is not None and isinstance(template, dict)
            self._deploy_template = template

            # Allow for recursive variable expansion as we do with global params and const_args.
            # NOTE: "deploymentTemplateParameters" is indexed unconditionally, so it must be
            # present in the config whenever "deploymentTemplatePath" is set.
            deploy_params = DictTemplater(self.config["deploymentTemplateParameters"]).expand_vars(
                extra_source_dict=global_config
            )
            self._deploy_params = merge_parameters(dest=deploy_params, source=global_config)
        else:
            _LOG.info(
                "No deploymentTemplatePath provided. Deployment services will be unavailable.",
            )

    @property
    def deploy_params(self) -> dict:
        """Get the deployment parameters."""
        return self._deploy_params

    @abc.abstractmethod
    def _set_default_params(self, params: dict) -> dict:
        """
        Optionally set some default parameters for the request.

        Parameters
        ----------
        params : dict
            The parameters.

        Returns
        -------
        dict
            The updated parameters.
        """
        raise NotImplementedError("Should be overridden by subclass.")

    def _get_session(self, params: dict) -> requests.Session:
        """
        Get a session object that includes automatic retries and headers for REST API
        calls.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs. The optional keys
            "requestTotalRetries" and "requestBackoffFactor" override the
            values configured for this service.

        Returns
        -------
        session : requests.Session
            A new session; the caller is responsible for closing it.
        """
        # Per-request overrides can come from the command line as strings, same as
        # the service-level config values converted in `__init__()`.
        total_retries = int(params.get("requestTotalRetries", self._total_retries))
        backoff_factor = float(params.get("requestBackoffFactor", self._backoff_factor))
        session = requests.Session()
        session.mount(
            "https://",
            HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)),
        )
        session.headers.update(self._get_headers())
        return session

    def _get_headers(self) -> dict:
        """Get the headers for the REST API calls from the parent authorization service."""
        assert self._parent is not None and isinstance(
            self._parent, SupportsAuth
        ), "Authorization service not provided. Include service-auth.jsonc?"
        return self._parent.get_auth_headers()

    @staticmethod
    def _extract_arm_parameters(json_data: dict) -> dict:
        """
        Extract parameters from the ARM Template REST response JSON.

        Parameters
        ----------
        json_data : dict
            Response JSON; parameters are read from `properties.parameters`.

        Returns
        -------
        parameters : dict
            Flat dictionary of parameters and their values.
            Parameters whose "value" is missing or None are omitted.
        """
        return {
            key: val.get("value")
            for (key, val) in json_data.get("properties", {}).get("parameters", {}).items()
            if val.get("value") is not None
        }

    def _azure_rest_api_post_helper(self, params: dict, url: str) -> tuple[Status, dict]:
        """
        General pattern for performing an action on an Azure resource via its REST API.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        url: str
            REST API url for the target to perform on the Azure VM.
            Should be a url that we intend to POST to.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED}
            Result will have a value for 'asyncResultsUrl' if status is PENDING,
            and 'pollInterval' if suggested by the API.
        """
        _LOG.debug("Request: POST %s", url)

        response = requests.post(url, headers=self._get_headers(), timeout=self._request_timeout)
        _LOG.debug("Response: %s", response)

        # Logical flow for async operations based on:
        # https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/async-operations
        if response.status_code == 200:
            return (Status.SUCCEEDED, params.copy())
        elif response.status_code == 202:
            result = params.copy()
            # Azure-AsyncOperation takes precedence over Location when both are present.
            if "Azure-AsyncOperation" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Azure-AsyncOperation")
            elif "Location" in response.headers:
                result["asyncResultsUrl"] = response.headers.get("Location")
            if "Retry-After" in response.headers:
                retry_after = response.headers["Retry-After"]
                try:
                    result["pollInterval"] = float(retry_after)
                except ValueError:
                    # Retry-After may also be an HTTP-date; in that case keep the
                    # configured poll interval instead of failing the request.
                    _LOG.warning("Ignoring non-numeric Retry-After header: %s", retry_after)

            return (Status.PENDING, result)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            # _LOG.error("Bad Request:\n%s", response.request.body)
            return (Status.FAILED, {})

    def _check_operation_status(self, params: dict) -> tuple[Status, dict]:
        """
        Checks the status of a pending operation on an Azure resource.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncResultsUrl" key to get the results.
            If the key is not present, return Status.PENDING.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, RUNNING, SUCCEEDED, FAILED}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        url = params.get("asyncResultsUrl")
        if url is None:
            return Status.PENDING, {}

        session = self._get_session(params)
        try:
            response = session.get(url, timeout=self._request_timeout)
        except requests.exceptions.ReadTimeout:
            # A read timeout is treated as "still running" rather than as a failure.
            _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
            return Status.RUNNING, {}
        except requests.exceptions.RequestException as ex:
            _LOG.exception("Error in request checking operation status", exc_info=ex)
            return (Status.FAILED, {})
        finally:
            # A new session is created for every poll; release its connections.
            session.close()

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Response: %s\n%s",
                response,
                json.dumps(response.json(), indent=2) if response.content else "",
            )

        if response.status_code == 200:
            output = response.json()
            status = output.get("status")
            if status == "InProgress":
                return Status.RUNNING, {}
            elif status == "Succeeded":
                return Status.SUCCEEDED, output

        # Any other HTTP code, or any other "status" value in a 200 response.
        _LOG.error("Response: %s :: %s", response, response.text)
        return Status.FAILED, {}

    def _wait_deployment(self, params: dict, *, is_setup: bool) -> tuple[Status, dict]:
        """
        Waits for a pending operation on an Azure resource to resolve to SUCCEEDED or
        FAILED. Return TIMED_OUT when timing out.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        is_setup : bool
            If True, wait for resource being deployed; otherwise, wait for
            successful deprovisioning. Only affects the log message here.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        params = self._set_default_params(params)
        _LOG.info(
            "Wait for %s to %s",
            params.get("deploymentName"),
            "provision" if is_setup else "deprovision",
        )
        return self._wait_while(self._check_deployment, Status.PENDING, params)

    def _wait_while(
        self,
        func: Callable[[dict], tuple[Status, dict]],
        loop_status: Status,
        params: dict,
    ) -> tuple[Status, dict]:
        """
        Invoke `func` periodically while the status is equal to `loop_status`. Return
        TIMED_OUT when timing out.

        Parameters
        ----------
        func : a function
            A function that takes `params` and returns a pair of (Status, {})
        loop_status: Status
            Steady state status - keep polling `func` while it returns `loop_status`.
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Requires deploymentName. An optional "pollInterval" (seconds)
            overrides the poll interval configured for this service.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result: the first result of `func` whose status
            differs from `loop_status`, or (TIMED_OUT, {}) once the poll timeout elapses.
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=["deploymentName"],
        )

        # The interval can come from the command line as a string, so convert it.
        poll_period = float(params.get("pollInterval", self._poll_interval))

        _LOG.debug(
            "Wait for %s status %s :: poll %.2f timeout %d s",
            config["deploymentName"],
            loop_status,
            poll_period,
            self._poll_timeout,
        )

        ts_timeout = time.time() + self._poll_timeout
        poll_delay = poll_period
        while True:
            # Wait for the suggested time first then check status
            ts_start = time.time()
            if ts_start >= ts_timeout:
                break

            if poll_delay > 0:
                _LOG.debug("Sleep for: %.2f of %.2f s", poll_delay, poll_period)
                time.sleep(poll_delay)

            (status, output) = func(params)
            if status != loop_status:
                return status, output

            # Subtract the time already spent in this iteration from the next sleep.
            ts_end = time.time()
            poll_delay = poll_period - ts_end + ts_start

        _LOG.warning("Request timed out: %s", params)
        return (Status.TIMED_OUT, {})

    def _check_deployment(self, params: dict) -> tuple[Status, dict]:
        # pylint: disable=too-many-return-statements
        """
        Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Merged with the service config; the result must contain
            "subscription", "resourceGroup", and "deploymentName".

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, PENDING, FAILED}
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "subscription",
                "resourceGroup",
                "deploymentName",
            ],
        )

        _LOG.info("Check deployment: %s", config["deploymentName"])

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        session = self._get_session(params)
        try:
            response = session.get(url, timeout=self._request_timeout)
        except requests.exceptions.ReadTimeout:
            _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
            # `_wait_deployment()` polls this method while it returns PENDING, so a
            # transient read timeout must report PENDING to keep the polling going.
            return Status.PENDING, {}
        except requests.exceptions.RequestException as ex:
            _LOG.exception("Error in request checking deployment", exc_info=ex)
            return (Status.FAILED, {})
        finally:
            # A new session is created for every poll; release its connections.
            session.close()

        _LOG.debug("Response: %s", response)

        if response.status_code == 200:
            output = response.json()
            state = output.get("properties", {}).get("provisioningState", "")

            if state == "Succeeded":
                return (Status.SUCCEEDED, {})
            elif state in {"Accepted", "Creating", "Deleting", "Running", "Updating"}:
                return (Status.PENDING, {})
            else:
                _LOG.error("Response: %s :: %s", response, json.dumps(output, indent=2))
                return (Status.FAILED, {})
        elif response.status_code == 404:
            # The deployment is not found (yet): keep waiting.
            return (Status.PENDING, {})

        _LOG.error("Response: %s :: %s", response, response.text)
        return (Status.FAILED, {})

    def _provision_resource(self, params: dict) -> tuple[Status, dict]:
        """
        Attempts to (re)deploy a resource.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Tunables are variable parameters that, together with the
            Environment configuration, are sufficient to provision the resource.

        Returns
        -------
        result : (Status, dict={})
            A pair of Status and result.
            Status is one of {PENDING, FAILED}.
            On HTTP 200 the result is the service config merged with `params`.
            On HTTP 201 the result is the merged deployment parameters plus the
            parameters extracted from the response JSON, with "asyncResultsUrl"
            and "deploymentName" filled in if not already present.
            The result is {} if the status is FAILED.

        Raises
        ------
        ValueError
            If no deployment template was loaded for this service.
        """
        if not self._deploy_template:
            raise ValueError(f"Missing deployment template: {self}")
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=["deploymentName"],
        )
        _LOG.info("Deploy: %s :: %s", config["deploymentName"], params)

        params = merge_parameters(dest=self._deploy_params.copy(), source=params)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Deploy: %s merged params ::\n%s",
                config["deploymentName"],
                json.dumps(params, indent=2),
            )

        url = self._URL_DEPLOY.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            deployment_name=config["deploymentName"],
        )

        # Only pass along the parameters that the template actually declares.
        json_req = {
            "properties": {
                "mode": "Incremental",
                "template": self._deploy_template,
                "parameters": {
                    key: {"value": val}
                    for (key, val) in params.items()
                    if key in self._deploy_template.get("parameters", {})
                },
            }
        }

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))

        response = requests.put(
            url,
            json=json_req,
            headers=self._get_headers(),
            timeout=self._request_timeout,
        )

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Response: %s\n%s",
                response,
                json.dumps(response.json(), indent=2) if response.content else "",
            )
        else:
            _LOG.info("Response: %s", response)

        if response.status_code == 200:
            return (Status.PENDING, config)
        elif response.status_code == 201:
            output = self._extract_arm_parameters(response.json())
            if _LOG.isEnabledFor(logging.DEBUG):
                _LOG.debug("Extracted parameters:\n%s", json.dumps(output, indent=2))
            params.update(output)
            params.setdefault("asyncResultsUrl", url)
            params.setdefault("deploymentName", config["deploymentName"])
            return (Status.PENDING, params)
        else:
            _LOG.error("Response: %s :: %s", response, response.text)
            # _LOG.error("Bad Request:\n%s", response.request.body)
            return (Status.FAILED, {})