Coverage for mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py: 100%
108 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Tests for mlos_bench.services.remote.azure.azure_vm_services."""
7from copy import deepcopy
8from unittest.mock import MagicMock, patch
10import pytest
11import requests.exceptions as requests_ex
13from mlos_bench.environments.status import Status
14from mlos_bench.services.remote.azure.azure_auth import AzureAuthService
15from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService
16from mlos_bench.tests.services.remote.azure import make_httplib_json_response
@pytest.mark.parametrize(
    ("total_retries", "operation_status"),
    [
        (2, Status.SUCCEEDED),
        (1, Status.FAILED),
        (0, Status.FAILED),
    ],
)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
def test_wait_host_deployment_retry(
    mock_getconn: MagicMock,
    total_retries: int,
    operation_status: Status,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that polling a host deployment tolerates transient connection drops.

    The mocked connection first reports a running deployment, then fails twice
    with ``ConnectionError``, then reports "Running" and finally "Succeeded".
    With enough request retries the poll ends in ``SUCCEEDED``; with too few it
    must end in a graceful ``FAILED`` state.
    """

    def _dropped_connection() -> requests_ex.ConnectionError:
        # A fresh exception object per simulated failure.
        return requests_ex.ConnectionError(
            "Connection aborted",
            OSError(107, "Transport endpoint is not connected"),
        )

    scripted_responses = [
        make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}),
        _dropped_connection(),
        _dropped_connection(),
        make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}),
        make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}),
    ]
    mock_getconn.return_value.getresponse.side_effect = scripted_responses

    deployment_params = {
        "pollInterval": 0.1,
        "requestTotalRetries": total_retries,
        "deploymentName": "TEST_DEPLOYMENT1",
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
    }
    status, _ = azure_vm_service.wait_host_deployment(params=deployment_params, is_setup=True)
    assert status == operation_status
def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAuthService) -> None:
    """
    Check that deployment template parameters referring to other parameters get expanded.

    ``vmMeta`` references ``$vmName`` and ``$location`` from the global config,
    and ``vmNsg`` in turn references ``$vmMeta``, so expansion must chain.
    """
    global_config = {
        "deploymentName": "TEST_DEPLOYMENT1",
        "vmName": "test-vm",
        "location": "eastus",
    }
    config = {
        "deploymentTemplatePath": (
            "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc"
        ),
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
        "deploymentTemplateParameters": {
            "location": "$location",
            "vmMeta": "$vmName-$location",
            "vmNsg": "$vmMeta-nsg",
        },
    }

    service = AzureVMService(config, global_config, parent=azure_auth_service)
    expected_location = global_config["location"]
    expected_vm_name = global_config["vmName"]

    assert service.deploy_params["location"] == expected_location
    assert service.deploy_params["vmMeta"] == f"{expected_vm_name}-{expected_location}"
    assert service.deploy_params["vmNsg"] == f"{service.deploy_params['vmMeta']}-nsg"
def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> None:
    """
    Test loading custom data from a file.

    Putting ``customData`` directly into ``deploymentTemplateParameters`` (here
    alongside ``customDataFile``) must make the service constructor raise
    ``ValueError``. With only ``customDataFile`` set, the service must populate
    a non-empty ``customData`` deployment parameter.
    """
    config = {
        "customDataFile": "services/remote/azure/cloud-init/alt-ssh.yml",
        "deploymentTemplatePath": (
            "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc"
        ),
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
        "deploymentTemplateParameters": {
            "location": "eastus2",
        },
    }
    global_config = {
        "deploymentName": "TEST_DEPLOYMENT1",
        "vmName": "test-vm",
    }

    # Build the conflicting config *outside* pytest.raises so that only the
    # constructor call is allowed to produce the expected ValueError.
    # deepcopy keeps the nested parameters dict of `config` untouched for the
    # positive case below.
    config_with_custom_data = deepcopy(config)
    template_params = config_with_custom_data["deploymentTemplateParameters"]
    assert isinstance(template_params, dict)
    template_params["customData"] = "DUMMY_CUSTOM_DATA"
    with pytest.raises(ValueError):
        AzureVMService(config_with_custom_data, global_config, parent=azure_auth_service)

    azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service)
    assert azure_vm_service.deploy_params["customData"]
@pytest.mark.parametrize(
    ("operation_name", "accepts_params"),
    [
        ("start_host", True),
        ("stop_host", True),
        ("shutdown", True),
        ("deprovision_host", True),
        ("deallocate_host", True),
        ("restart_host", True),
        ("reboot", True),
    ],
)
@pytest.mark.parametrize(
    ("http_status_code", "operation_status"),
    [
        (200, Status.SUCCEEDED),
        (202, Status.PENDING),
        (401, Status.FAILED),
        (404, Status.FAILED),
    ],
)
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests")
# pylint: disable=too-many-arguments
def test_vm_operation_status(
    mock_requests: MagicMock,
    azure_vm_service: AzureVMService,
    operation_name: str,
    accepts_params: bool,
    http_status_code: int,
    operation_status: Status,
) -> None:
    """
    Test VM operation status.

    Each operation is first called without ``vmName``, which must raise
    ``ValueError``. It is then called with ``vmName``, and the returned status
    must match the status code of the mocked ``requests.post`` response.
    """
    mock_response = MagicMock()
    mock_response.status_code = http_status_code
    mock_requests.post.return_value = mock_response

    operation = getattr(azure_vm_service, operation_name)
    with pytest.raises(ValueError):
        # Missing vmName should raise ValueError.
        # Do not unpack the result in here: a tuple-unpacking mismatch is also
        # a ValueError and would let this check pass for the wrong reason.
        if accepts_params:
            operation({})
        else:
            operation()
    (status, _) = operation({"vmName": "test-vm"}) if accepts_params else operation()
    assert status == operation_status
@pytest.mark.parametrize(
    ("operation_name", "accepts_params"),
    [
        ("provision_host", True),
    ],
)
def test_vm_operation_invalid(
    azure_vm_service_remote_exec_only: AzureVMService,
    operation_name: str,
    accepts_params: bool,
) -> None:
    """
    Test VM operation status for an incomplete service config.

    Each listed operation, called on the ``azure_vm_service_remote_exec_only``
    fixture, must raise ``ValueError`` even when ``vmName`` is supplied.
    """
    operation = getattr(azure_vm_service_remote_exec_only, operation_name)
    with pytest.raises(ValueError):
        # Do not unpack the result inside pytest.raises: an unpacking mismatch
        # is also a ValueError and would mask a missing config check.
        if accepts_params:
            operation({"vmName": "test-vm"})
        else:
            operation()
@patch("mlos_bench.services.remote.azure.azure_deployment_services.time.sleep")
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session")
def test_wait_vm_operation_ready(
    mock_session: MagicMock,
    mock_sleep: MagicMock,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that waiting on a remote VM operation stops once it reports success.

    The operation is polled through the mocked session at ``asyncResultsUrl``.
    The wait sleeps for ``pollInterval`` between polls, and the final status
    must be a success.
    """
    polled_url = "DUMMY_ASYNC_URL"
    poll_interval = 12345
    params = {
        "asyncResultsUrl": polled_url,
        "vmName": "test-vm",
        "pollInterval": poll_interval,
    }

    # Every GET on the mocked session yields an already-succeeded operation.
    succeeded_response = MagicMock(status_code=200)
    succeeded_response.json.return_value = {"status": "Succeeded"}
    session = mock_session.return_value
    session.get.return_value = succeeded_response

    status, _ = azure_vm_service.wait_host_operation(params)

    assert session.get.call_args.args == (polled_url,)
    assert mock_sleep.call_args.args == (poll_interval,)
    assert status.is_succeeded()
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session")
def test_wait_vm_operation_timeout(
    mock_session: MagicMock,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that waiting on a never-finishing remote VM operation times out.

    The mocked session always reports ``InProgress``, so the wait must end
    with ``Status.TIMED_OUT``.
    """
    in_progress_response = MagicMock(status_code=200)
    in_progress_response.json.return_value = {"status": "InProgress"}
    mock_session.return_value.get.return_value = in_progress_response

    wait_params = {"asyncResultsUrl": "DUMMY_ASYNC_URL", "vmName": "test-vm", "pollInterval": 1}
    status, _ = azure_vm_service.wait_host_operation(wait_params)
    assert status == Status.TIMED_OUT
@pytest.mark.parametrize(
    ("total_retries", "operation_status"),
    [
        (2, Status.SUCCEEDED),
        (1, Status.FAILED),
        (0, Status.FAILED),
    ],
)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
def test_wait_vm_operation_retry(
    mock_getconn: MagicMock,
    total_retries: int,
    operation_status: Status,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that polling a remote VM operation tolerates transient connection drops.

    The mocked connection reports ``InProgress``, fails twice with
    ``ConnectionError``, reports ``InProgress`` again and finally ``Succeeded``.
    Enough request retries must yield ``SUCCEEDED``; too few must end in a
    graceful ``FAILED`` state.
    """

    def _dropped_connection() -> requests_ex.ConnectionError:
        # A fresh exception object per simulated failure.
        return requests_ex.ConnectionError(
            "Connection aborted",
            OSError(107, "Transport endpoint is not connected"),
        )

    scripted_responses = [
        make_httplib_json_response(200, {"status": "InProgress"}),
        _dropped_connection(),
        _dropped_connection(),
        make_httplib_json_response(200, {"status": "InProgress"}),
        make_httplib_json_response(200, {"status": "Succeeded"}),
    ]
    mock_getconn.return_value.getresponse.side_effect = scripted_responses

    wait_params = {
        "pollInterval": 0.1,
        "requestTotalRetries": total_retries,
        "asyncResultsUrl": "https://DUMMY_ASYNC_URL",
        "vmName": "test-vm",
    }
    status, _ = azure_vm_service.wait_host_operation(params=wait_params)
    assert status == operation_status
@pytest.mark.parametrize(
    ("http_status_code", "operation_status"),
    [
        (200, Status.SUCCEEDED),
        (202, Status.PENDING),
        (401, Status.FAILED),
        (404, Status.FAILED),
    ],
)
@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
def test_remote_exec_status(
    mock_requests: MagicMock,
    azure_vm_service_remote_exec_only: AzureVMService,
    http_status_code: int,
    operation_status: Status,
) -> None:
    """
    Check the status reported when starting a remote script execution on Azure.

    The HTTP status code of the mocked ``requests.post`` response determines
    which ``Status`` ``remote_exec`` must return.
    """
    post_response = MagicMock()
    post_response.status_code = http_status_code
    post_response.json = MagicMock(return_value={"fake response": "body as json to dict"})
    mock_requests.post.return_value = post_response

    commands = ["command_1", "command_2"]
    status, _ = azure_vm_service_remote_exec_only.remote_exec(
        commands,
        config={"vmName": "test-vm"},
        env_params={},
    )

    assert status == operation_status
@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
def test_remote_exec_headers_output(
    mock_requests: MagicMock,
    azure_vm_service_remote_exec_only: AzureVMService,
) -> None:
    """
    Check the output and request payload of a remote script execution on Azure.

    The ``Azure-AsyncOperation`` response header must be surfaced in the command
    output as ``asyncResultsUrl``. The POSTed JSON body must carry the script and
    the environment parameters as a list of name/value pairs.
    """
    results_key = "asyncResultsUrl"
    results_url = "DUMMY_ASYNC_URL"
    commands = ["command_1", "command_2"]

    accepted_response = MagicMock()
    accepted_response.status_code = 202
    accepted_response.headers = {"Azure-AsyncOperation": results_url}
    accepted_response.json = MagicMock(return_value={"fake response": "body as json to dict"})
    mock_requests.post.return_value = accepted_response

    env = {
        "param_1": 123,
        "param_2": "abc",
    }
    _, cmd_output = azure_vm_service_remote_exec_only.remote_exec(
        commands,
        config={"vmName": "test-vm"},
        env_params=env,
    )

    assert results_key in cmd_output
    assert cmd_output[results_key] == results_url

    expected_payload = {
        "commandId": "RunShellScript",
        "script": commands,
        "parameters": [{"name": "param_1", "value": 123}, {"name": "param_2", "value": "abc"}],
    }
    posted_kwargs = mock_requests.post.call_args[1]
    assert posted_kwargs["json"] == expected_payload
@pytest.mark.parametrize(
    ("operation_status", "wait_output", "results_output"),
    [
        (
            Status.SUCCEEDED,
            {
                "properties": {
                    "output": {
                        "value": [
                            {"message": "DUMMY_STDOUT_STDERR"},
                        ]
                    }
                }
            },
            {"stdout": "DUMMY_STDOUT_STDERR"},
        ),
        (Status.PENDING, {}, {}),
        (Status.FAILED, {}, {}),
    ],
)
def test_get_remote_exec_results(
    azure_vm_service_remote_exec_only: AzureVMService,
    operation_status: Status,
    wait_output: dict,
    results_output: dict,
) -> None:
    """
    Test getting the results of the remote execution on Azure.

    ``wait_host_operation`` is replaced by a mock returning
    ``(operation_status, wait_output)``. ``get_remote_exec_results`` must pass
    the params through to it, return the same status, and turn ``wait_output``
    into ``results_output``.
    """
    params = {"asyncResultsUrl": "DUMMY_ASYNC_URL"}

    # patch.object (rather than a bare setattr) restores the real method when
    # the block exits, so the fixture object is not left permanently modified.
    with patch.object(
        azure_vm_service_remote_exec_only,
        "wait_host_operation",
        return_value=(operation_status, wait_output),
    ) as mock_wait_host_operation:
        status, cmd_output = azure_vm_service_remote_exec_only.get_remote_exec_results(params)

    assert status == operation_status
    assert mock_wait_host_operation.call_args[0][0] == params
    assert cmd_output == results_output