Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_saas.py: 46%
82 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""A collection of Service functions for configuring SaaS instances on Azure."""
6import logging
7from collections.abc import Callable
8from typing import Any
10import requests
12from mlos_bench.environments.status import Status
13from mlos_bench.services.base_service import Service
14from mlos_bench.services.types.authenticator_type import SupportsAuth
15from mlos_bench.services.types.remote_config_type import SupportsRemoteConfig
16from mlos_bench.util import check_required_params, merge_parameters
# Module-level logger, named after this module.
_LOG = logging.getLogger(name=__name__)
class AzureSaaSConfigService(Service, SupportsRemoteConfig):
    """
    Helper methods to configure Azure SaaS database services (MySQL, MariaDB,
    PostgreSQL), both "flexibleServers" and single "servers", via the Azure REST API.
    """

    _REQUEST_TIMEOUT = 5  # seconds

    # REST API for Azure SaaS DB Services configuration as described in:
    # https://learn.microsoft.com/en-us/rest/api/mysql/flexibleserver/configurations
    # https://learn.microsoft.com/en-us/rest/api/postgresql/flexibleserver/configurations
    # https://learn.microsoft.com/en-us/rest/api/mariadb/configurations
    _URL_CONFIGURE = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/{provider}"
        "/{server_type}/{vm_name}"
        "/{update}"
        "?api-version={api_version}"
    )

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new instance of Azure services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Must contain "subscription", "resourceGroup", and
            "provider". For providers other than Microsoft.DBforMySQL,
            Microsoft.DBforMariaDB, and Microsoft.DBforPostgreSQL it must also
            contain "supportsBatchUpdate", "isFlex", and "apiVersion".
            Optional "requestTimeout" overrides the HTTP timeout (seconds).
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : dict[str, Callable] | list[Callable] | None
            New methods to register with the service.
        """
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(methods, [self.configure, self.is_config_pending]),
        )

        check_required_params(
            self.config,
            {
                "subscription",
                "resourceGroup",
                "provider",
            },
        )

        # Provide sane defaults for known DB providers.
        # NOTE(review): "supportsBatchUpdate" and "isFlex" are used as-is; if they
        # arrive as strings (e.g., "false" from the command line) they would be truthy.
        provider = self.config.get("provider")
        if provider == "Microsoft.DBforMySQL":
            self._is_batch = self.config.get("supportsBatchUpdate", True)
            is_flex = self.config.get("isFlex", True)
            api_version = self.config.get("apiVersion", "2022-01-01")
        elif provider == "Microsoft.DBforMariaDB":
            self._is_batch = self.config.get("supportsBatchUpdate", False)
            is_flex = self.config.get("isFlex", False)
            api_version = self.config.get("apiVersion", "2018-06-01")
        elif provider == "Microsoft.DBforPostgreSQL":
            self._is_batch = self.config.get("supportsBatchUpdate", False)
            is_flex = self.config.get("isFlex", True)
            api_version = self.config.get("apiVersion", "2022-12-01")
        else:
            self._is_batch = self.config["supportsBatchUpdate"]
            is_flex = self.config["isFlex"]
            api_version = self.config["apiVersion"]

        server_type = "flexibleServers" if is_flex else "servers"

        # Two-stage formatting: "{vm_name}" (and, in per-parameter mode,
        # "{param_name}") are passed in as literal placeholder text, so they
        # survive this .format() call and are filled in per request later.
        self._url_config_set = self._URL_CONFIGURE.format(
            subscription=self.config["subscription"],
            resource_group=self.config["resourceGroup"],
            provider=self.config["provider"],
            vm_name="{vm_name}",
            server_type=server_type,
            update="updateConfigurations" if self._is_batch else "configurations/{param_name}",
            api_version=api_version,
        )

        self._url_config_get = self._URL_CONFIGURE.format(
            subscription=self.config["subscription"],
            resource_group=self.config["resourceGroup"],
            provider=self.config["provider"],
            vm_name="{vm_name}",
            server_type=server_type,
            update="configurations",
            api_version=api_version,
        )

        # These parameters can come from command line as strings, so conversion is needed.
        self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT))

    def configure(self, config: dict[str, Any], params: dict[str, Any]) -> tuple[Status, dict]:
        """
        Update the parameters of an Azure DB service.

        Uses the batch-update API when it is enabled for this provider, and
        one request per parameter otherwise.

        Parameters
        ----------
        config : dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        if self._is_batch:
            return self._config_batch(config, params)
        return self._config_many(config, params)

    def is_config_pending(self, config: dict[str, Any]) -> tuple[Status, dict]:
        """
        Check if the configuration of an Azure DB service requires a reboot or restart.

        Parameters
        ----------
        config : dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. On success, the result has a single
            Boolean field "isConfigPendingReboot" that is True if any server
            configuration entry reports "isConfigPendingRestart", i.e., the
            service must be restarted for the new settings to take effect.
            On failure, the result is {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
        url = self._url_config_get.format(vm_name=config["vmName"])
        _LOG.debug("Request: GET %s", url)
        # Listing the server configurations is a read-only GET request.
        response = requests.get(url, headers=self._get_headers(), timeout=self._request_timeout)
        _LOG.debug("Response: %s :: %s", response, response.text)
        status = self._status_from_response(response)
        if not status.is_succeeded():
            return (status, {})
        try:
            entries = response.json()["value"]
        except (ValueError, KeyError, TypeError) as ex:
            # Treat an unexpected payload as a failed check rather than crashing.
            _LOG.error("Unexpected response format from %s: %s", url, ex)
            return (Status.FAILED, {})
        # Currently, Azure Flex servers require a VM reboot.
        is_pending = any(
            self._is_true(entry.get("properties", {}).get("isConfigPendingRestart"))
            for entry in entries
        )
        return (Status.SUCCEEDED, {"isConfigPendingReboot": is_pending})

    def _get_headers(self) -> dict:
        """Get the headers for the REST API calls."""
        assert self._parent is not None and isinstance(
            self._parent, SupportsAuth
        ), "Authorization service not provided. Include service-auth.jsonc?"
        return self._parent.get_auth_headers()

    @staticmethod
    def _status_from_response(response: requests.Response) -> Status:
        """
        Map the HTTP status code of an Azure REST response to a Status.

        504 -> TIMED_OUT, 200 -> SUCCEEDED, anything else -> FAILED.
        """
        # NOTE(review): Azure may answer 202 Accepted for long-running updates;
        # that is currently reported as FAILED.
        if response.status_code == 504:
            return Status.TIMED_OUT
        if response.status_code == 200:
            return Status.SUCCEEDED
        return Status.FAILED

    @staticmethod
    def _is_true(value: Any) -> bool:
        """
        Interpret an Azure "isConfigPendingRestart"-style flag.

        Depending on the provider and API version the flag may be a JSON
        boolean or the strings "True"/"False"; missing/other values are False.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            return value.strip().lower() == "true"
        return False

    def _config_one(
        self,
        config: dict[str, Any],
        param_name: str,
        param_value: Any,
    ) -> tuple[Status, dict]:
        """
        Update a single parameter of the Azure DB service.

        Parameters
        ----------
        config : dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        param_name : str
            Name of the parameter to update.
        param_value : Any
            Value of the parameter to update (sent as a string).

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
        url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name)
        _LOG.debug("Request: PUT %s", url)
        response = requests.put(
            url,
            headers=self._get_headers(),
            json={"properties": {"value": str(param_value)}},
            timeout=self._request_timeout,
        )
        _LOG.debug("Response: %s :: %s", response, response.text)
        return (self._status_from_response(response), {})

    def _config_many(self, config: dict[str, Any], params: dict[str, Any]) -> tuple[Status, dict]:
        """
        Update the parameters of an Azure DB service one-by-one. (If batch API is not
        available for it).

        Stops at the first parameter whose update does not succeed.

        Parameters
        ----------
        config : dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        for param_name, param_value in params.items():
            (status, result) = self._config_one(config, param_name, param_value)
            if not status.is_succeeded():
                return (status, result)
        return (Status.SUCCEEDED, {})

    def _config_batch(self, config: dict[str, Any], params: dict[str, Any]) -> tuple[Status, dict]:
        """
        Batch update the parameters of an Azure DB service.

        Parameters
        ----------
        config : dict[str, Any]
            Key/value pairs of configuration parameters (e.g., vmName).
        params : dict[str, Any]
            Key/value pairs of the service parameters to update.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is always {}.
            Status is one of {SUCCEEDED, TIMED_OUT, FAILED}
        """
        config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"])
        url = self._url_config_set.format(vm_name=config["vmName"])
        # All values are sent as strings. The batch API also has a
        # "resetAllToDefault" flag, which is intentionally not set here.
        json_req = {
            "value": [
                {"name": key, "properties": {"value": str(val)}} for (key, val) in params.items()
            ],
        }
        _LOG.debug("Request: POST %s", url)
        response = requests.post(
            url,
            headers=self._get_headers(),
            json=json_req,
            timeout=self._request_timeout,
        )
        _LOG.debug("Response: %s :: %s", response, response.text)
        return (self._status_from_response(response), {})