Coverage for mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py: 100%

108 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for mlos_bench.services.remote.azure.azure_vm_services.""" 

6 

7from copy import deepcopy 

8from unittest.mock import MagicMock, patch 

9 

10import pytest 

11import requests.exceptions as requests_ex 

12 

13from mlos_bench.environments.status import Status 

14from mlos_bench.services.remote.azure.azure_auth import AzureAuthService 

15from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService 

16from mlos_bench.tests.services.remote.azure import make_httplib_json_response 

17 

18 

@pytest.mark.parametrize(
    ("total_retries", "operation_status"),
    [
        (2, Status.SUCCEEDED),
        (1, Status.FAILED),
        (0, Status.FAILED),
    ],
)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
def test_wait_host_deployment_retry(
    mock_getconn: MagicMock,
    total_retries: int,
    operation_status: Status,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that polling a host deployment copes with dropped connections.

    The mocked connection yields one "Running" reply, two consecutive connection
    errors, another "Running" reply and finally "Succeeded". Each parametrized
    case pairs a ``requestTotalRetries`` budget with the status that
    ``wait_host_deployment`` is expected to report for it.
    """
    # Two independent error instances stand in for the intermittent outage.
    dropped_connections = [
        requests_ex.ConnectionError(
            "Connection aborted",
            OSError(107, "Transport endpoint is not connected"),
        )
        for _ in range(2)
    ]
    mock_getconn.return_value.getresponse.side_effect = [
        make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}),
        *dropped_connections,
        make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}),
        make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}),
    ]

    deployment_params = {
        "pollInterval": 0.1,
        "requestTotalRetries": total_retries,
        "deploymentName": "TEST_DEPLOYMENT1",
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
    }
    status, _output = azure_vm_service.wait_host_deployment(
        params=deployment_params,
        is_setup=True,
    )
    assert status == operation_status

62 

63 

def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAuthService) -> None:
    """
    Check that ``$name`` references in the template parameters are expanded.

    ``vmMeta`` refers to two global values, and ``vmNsg`` refers to ``vmMeta``,
    so the last assertion only holds if expansion is applied recursively.
    """
    global_config = {
        "deploymentName": "TEST_DEPLOYMENT1",
        "vmName": "test-vm",
        "location": "eastus",
    }
    config = {
        "deploymentTemplatePath": (
            "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc"
        ),
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
        "deploymentTemplateParameters": {
            "location": "$location",
            "vmMeta": "$vmName-$location",
            "vmNsg": "$vmMeta-nsg",
        },
    }

    azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service)
    deploy_params = azure_vm_service.deploy_params

    # Plain reference to a single global value.
    assert deploy_params["location"] == global_config["location"]

    # Two global values combined in one parameter.
    expected_vm_meta = f'{global_config["vmName"]}-{global_config["location"]}'
    assert deploy_params["vmMeta"] == expected_vm_meta

    # Reference to another (already expanded) template parameter.
    expected_vm_nsg = f'{deploy_params["vmMeta"]}-nsg'
    assert deploy_params["vmNsg"] == expected_vm_nsg

93 

94 

def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> None:
    """
    Test loading custom data from a file.

    Two cases are covered:

    * Supplying ``customDataFile`` together with an explicit ``customData``
      template parameter must make the ``AzureVMService`` constructor raise
      ``ValueError``.
    * Supplying ``customDataFile`` alone must yield a non-empty ``customData``
      entry in ``deploy_params``.
    """
    config = {
        "customDataFile": "services/remote/azure/cloud-init/alt-ssh.yml",
        "deploymentTemplatePath": (
            "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc"
        ),
        "subscription": "TEST_SUB1",
        "resourceGroup": "TEST_RG1",
        "deploymentTemplateParameters": {
            "location": "eastus2",
        },
    }
    global_config = {
        "deploymentName": "TEST_DEPLOYMENT1",
        "vmName": "test-vm",
    }

    # Build the conflicting config outside of the pytest.raises() block so that
    # only the constructor call can satisfy the ValueError expectation.
    config_with_custom_data = deepcopy(config)
    conflicting_params = config_with_custom_data["deploymentTemplateParameters"]
    conflicting_params["customData"] = "DUMMY_CUSTOM_DATA"  # type: ignore[index]
    with pytest.raises(ValueError):
        AzureVMService(config_with_custom_data, global_config, parent=azure_auth_service)

    # The original config was deep-copied above, so it is still free of the conflict.
    azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service)
    assert azure_vm_service.deploy_params["customData"]

118 

119 

@pytest.mark.parametrize(
    ("operation_name", "accepts_params"),
    [
        ("start_host", True),
        ("stop_host", True),
        ("shutdown", True),
        ("deprovision_host", True),
        ("deallocate_host", True),
        ("restart_host", True),
        ("reboot", True),
    ],
)
@pytest.mark.parametrize(
    ("http_status_code", "operation_status"),
    [
        (200, Status.SUCCEEDED),
        (202, Status.PENDING),
        (401, Status.FAILED),
        (404, Status.FAILED),
    ],
)
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests")
# pylint: disable=too-many-arguments
def test_vm_operation_status(
    mock_requests: MagicMock,
    azure_vm_service: AzureVMService,
    operation_name: str,
    accepts_params: bool,
    http_status_code: int,
    operation_status: Status,
) -> None:
    """
    Test VM operation status.

    Every operation name is combined with every HTTP status code: the mocked
    ``requests.post`` returns a response with that code, and the status
    returned by the operation must match ``operation_status``. Before that,
    the operation is invoked without ``vmName`` and must raise ``ValueError``.
    """
    mock_response = MagicMock()
    mock_response.status_code = http_status_code
    mock_requests.post.return_value = mock_response

    operation = getattr(azure_vm_service, operation_name)
    with pytest.raises(ValueError):
        # Missing vmName should raise ValueError.
        # The result is deliberately not unpacked here: a failed tuple
        # unpacking raises ValueError as well, which would let a wrongly
        # shaped return value satisfy this expectation.
        if accepts_params:
            operation({})
        else:
            operation()

    (status, _) = operation({"vmName": "test-vm"}) if accepts_params else operation()
    assert status == operation_status

162 

163 

@pytest.mark.parametrize(
    ("operation_name", "accepts_params"),
    [
        ("provision_host", True),
    ],
)
def test_vm_operation_invalid(
    azure_vm_service_remote_exec_only: AzureVMService,
    operation_name: str,
    accepts_params: bool,
) -> None:
    """
    Test VM operation status for an incomplete service config.

    The operation is looked up on the ``azure_vm_service_remote_exec_only``
    fixture and is expected to raise ``ValueError`` when invoked.
    """
    operation = getattr(azure_vm_service_remote_exec_only, operation_name)
    with pytest.raises(ValueError):
        # Do not unpack the result inside this block: a failed tuple unpacking
        # raises ValueError too and would mask an operation that did not raise.
        if accepts_params:
            operation({"vmName": "test-vm"})
        else:
            operation()

179 

180 

@patch("mlos_bench.services.remote.azure.azure_deployment_services.time.sleep")
@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session")
def test_wait_vm_operation_ready(
    mock_session: MagicMock,
    mock_sleep: MagicMock,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that waiting on a VM operation polls the async URL and reports success.

    Both ``requests.Session`` and ``time.sleep`` are mocked, so the test can
    inspect which URL was requested and which delay was passed to ``sleep``.
    """
    async_url = "DUMMY_ASYNC_URL"
    retry_after = 12345

    # The status endpoint immediately reports a finished operation.
    succeeded_response = MagicMock(status_code=200)
    succeeded_response.json.return_value = {"status": "Succeeded"}
    session_get = mock_session.return_value.get
    session_get.return_value = succeeded_response

    status, _output = azure_vm_service.wait_host_operation(
        {
            "asyncResultsUrl": async_url,
            "vmName": "test-vm",
            "pollInterval": retry_after,
        }
    )

    # Positional arguments of the most recent call to each mock.
    assert session_get.call_args[0] == (async_url,)
    assert mock_sleep.call_args[0] == (retry_after,)
    assert status.is_succeeded()

209 

210 

@patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session")
def test_wait_vm_operation_timeout(
    mock_session: MagicMock,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that an operation which never leaves "InProgress" ends as TIMED_OUT.

    The mocked session answers every status request with HTTP 200 and an
    "InProgress" body.
    """
    in_progress_response = MagicMock(status_code=200)
    in_progress_response.json.return_value = {"status": "InProgress"}
    mock_session.return_value.get.return_value = in_progress_response

    params = {
        "asyncResultsUrl": "DUMMY_ASYNC_URL",
        "vmName": "test-vm",
        "pollInterval": 1,
    }
    status, _output = azure_vm_service.wait_host_operation(params)

    assert status == Status.TIMED_OUT

228 

229 

@pytest.mark.parametrize(
    ("total_retries", "operation_status"),
    [
        (2, Status.SUCCEEDED),
        (1, Status.FAILED),
        (0, Status.FAILED),
    ],
)
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
def test_wait_vm_operation_retry(
    mock_getconn: MagicMock,
    total_retries: int,
    operation_status: Status,
    azure_vm_service: AzureVMService,
) -> None:
    """
    Check that polling a remote VM operation copes with dropped connections.

    The mocked connection yields one "InProgress" reply, two consecutive
    connection errors, another "InProgress" reply and finally "Succeeded".
    Each parametrized case pairs a ``requestTotalRetries`` budget with the
    status that ``wait_host_operation`` is expected to report for it.
    """
    # Two independent error instances stand in for the intermittent outage.
    dropped_connections = [
        requests_ex.ConnectionError(
            "Connection aborted",
            OSError(107, "Transport endpoint is not connected"),
        )
        for _ in range(2)
    ]
    mock_getconn.return_value.getresponse.side_effect = [
        make_httplib_json_response(200, {"status": "InProgress"}),
        *dropped_connections,
        make_httplib_json_response(200, {"status": "InProgress"}),
        make_httplib_json_response(200, {"status": "Succeeded"}),
    ]

    operation_params = {
        "pollInterval": 0.1,
        "requestTotalRetries": total_retries,
        "asyncResultsUrl": "https://DUMMY_ASYNC_URL",
        "vmName": "test-vm",
    }
    status, _output = azure_vm_service.wait_host_operation(params=operation_params)
    assert status == operation_status

271 

272 

@pytest.mark.parametrize(
    ("http_status_code", "operation_status"),
    [
        (200, Status.SUCCEEDED),
        (202, Status.PENDING),
        (401, Status.FAILED),
        (404, Status.FAILED),
    ],
)
@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
def test_remote_exec_status(
    mock_requests: MagicMock,
    azure_vm_service_remote_exec_only: AzureVMService,
    http_status_code: int,
    operation_status: Status,
) -> None:
    """
    Check the status returned by ``remote_exec`` for several HTTP status codes.

    The mocked ``requests.post`` returns a response carrying the parametrized
    status code and a placeholder JSON body.
    """
    post_response = MagicMock(status_code=http_status_code)
    post_response.json.return_value = {
        "fake response": "body as json to dict",
    }
    mock_requests.post.return_value = post_response

    status, _output = azure_vm_service_remote_exec_only.remote_exec(
        ["command_1", "command_2"],
        config={"vmName": "test-vm"},
        env_params={},
    )

    assert status == operation_status

308 

309 

@patch("mlos_bench.services.remote.azure.azure_vm_services.requests")
def test_remote_exec_headers_output(
    mock_requests: MagicMock,
    azure_vm_service_remote_exec_only: AzureVMService,
) -> None:
    """
    Check what ``remote_exec`` returns and what it sends for an accepted request.

    The mocked response has status 202 and an ``Azure-AsyncOperation`` header.
    The header value must appear in the returned output under
    ``asyncResultsUrl``, and the JSON payload handed to ``requests.post`` must
    contain the script and the environment parameters as name/value pairs.
    """
    script = ["command_1", "command_2"]
    async_url_key = "asyncResultsUrl"
    async_url_value = "DUMMY_ASYNC_URL"

    accepted_response = MagicMock(
        status_code=202,
        headers={"Azure-AsyncOperation": async_url_value},
    )
    accepted_response.json.return_value = {
        "fake response": "body as json to dict",
    }
    mock_requests.post.return_value = accepted_response

    _status, cmd_output = azure_vm_service_remote_exec_only.remote_exec(
        script,
        config={"vmName": "test-vm"},
        env_params={
            "param_1": 123,
            "param_2": "abc",
        },
    )

    # The async operation URL from the response header is passed back to the caller.
    assert async_url_key in cmd_output
    assert cmd_output[async_url_key] == async_url_value

    # Inspect the keyword arguments of the POST request that was issued.
    _post_args, post_kwargs = mock_requests.post.call_args
    expected_payload = {
        "commandId": "RunShellScript",
        "script": script,
        "parameters": [
            {"name": "param_1", "value": 123},
            {"name": "param_2", "value": "abc"},
        ],
    }
    assert post_kwargs["json"] == expected_payload

347 

348 

@pytest.mark.parametrize(
    ("operation_status", "wait_output", "results_output"),
    [
        (
            Status.SUCCEEDED,
            {
                "properties": {
                    "output": {
                        "value": [
                            {"message": "DUMMY_STDOUT_STDERR"},
                        ]
                    }
                }
            },
            {"stdout": "DUMMY_STDOUT_STDERR"},
        ),
        (Status.PENDING, {}, {}),
        (Status.FAILED, {}, {}),
    ],
)
def test_get_remote_exec_results(
    azure_vm_service_remote_exec_only: AzureVMService,
    operation_status: Status,
    wait_output: dict,
    results_output: dict,
) -> None:
    """
    Check how ``get_remote_exec_results`` converts the outcome of the wait call.

    ``wait_host_operation`` is replaced by a mock returning the parametrized
    status and payload. The status must be passed through unchanged, and the
    payload must be turned into the expected ``results_output`` dictionary.
    """
    params = {"asyncResultsUrl": "DUMMY_ASYNC_URL"}

    # setattr() has the same effect as assigning the attribute directly;
    # presumably it is used here to avoid type-checker complaints about
    # assigning to a method - TODO confirm.
    wait_stub = MagicMock(return_value=(operation_status, wait_output))
    setattr(azure_vm_service_remote_exec_only, "wait_host_operation", wait_stub)

    status, cmd_output = azure_vm_service_remote_exec_only.get_remote_exec_results(params)

    assert status == operation_status
    # The first positional argument of the wait call is the params dict.
    assert wait_stub.call_args[0][0] == params
    assert cmd_output == results_output