Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py: 80%

137 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""A collection of Service functions for managing VMs on Azure."""

6 

7import json 

8import logging 

9from collections.abc import Callable, Iterable 

10from datetime import datetime 

11from typing import Any 

12 

13import requests 

14 

15from mlos_bench.environments.status import Status 

16from mlos_bench.services.base_service import Service 

17from mlos_bench.services.remote.azure.azure_deployment_services import ( 

18 AzureDeploymentService, 

19) 

20from mlos_bench.services.types.host_ops_type import SupportsHostOps 

21from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning 

22from mlos_bench.services.types.os_ops_type import SupportsOSOps 

23from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec 

24from mlos_bench.util import merge_parameters 

25 

26_LOG = logging.getLogger(__name__) 

27 

28 

class AzureVMService(
    AzureDeploymentService,
    SupportsHostProvisioning,
    SupportsHostOps,
    SupportsOSOps,
    SupportsRemoteExec,
):
    """
    Helper methods to manage VMs on Azure.

    The service talks to the Azure Compute REST API: host power operations are
    issued as POST requests built from the ``_URL_*`` templates below, and
    remote execution uses the VM "run commands" endpoint (PUT to submit, GET to
    poll). Request plumbing (headers, sessions, polling, deployment handling) is
    provided by the base classes and is not defined in this class.
    """

    # pylint: disable=too-many-ancestors

    # Azure Compute REST API calls as described in
    # https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines

    # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/start
    _URL_START = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/Microsoft.Compute"
        "/virtualMachines/{vm_name}"
        "/start"
        "?api-version=2022-03-01"
    )

    # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/power-off
    _URL_STOP = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/Microsoft.Compute"
        "/virtualMachines/{vm_name}"
        "/powerOff"
        "?api-version=2022-03-01"
    )

    # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/deallocate
    _URL_DEALLOCATE = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/Microsoft.Compute"
        "/virtualMachines/{vm_name}"
        "/deallocate"
        "?api-version=2022-03-01"
    )

    # TODO: The delete URL (commented out below) is probably the more correct
    # URL to use for the deprovision operation.
    # However, previous code used the deallocate URL above, so for now, we keep
    # that and handle that change later.
    # See Also: #498
    _URL_DEPROVISION = _URL_DEALLOCATE

    # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/delete
    # _URL_DEPROVISION = (
    #    "https://management.azure.com"
    #    "/subscriptions/{subscription}"
    #    "/resourceGroups/{resource_group}"
    #    "/providers/Microsoft.Compute"
    #    "/virtualMachines/{vm_name}"
    #    "/delete"
    #    "?api-version=2022-03-01"
    # )

    # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/restart
    _URL_REBOOT = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/Microsoft.Compute"
        "/virtualMachines/{vm_name}"
        "/restart"
        "?api-version=2022-03-01"
    )

    # From:
    # https://learn.microsoft.com/en-us/rest/api/compute/virtual-machine-run-commands/create-or-update
    # Endpoint used to submit (PUT) a named run command on the VM.
    _URL_REXEC_RUN = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/Microsoft.Compute"
        "/virtualMachines/{vm_name}"
        "/runcommands/{command_name}"
        "?api-version=2024-07-01"
    )
    # Same resource as above, but expanded with the instance view, which is
    # where the execution state, exit code and output are read from.
    _URL_REXEC_RESULT = (
        "https://management.azure.com"
        "/subscriptions/{subscription}"
        "/resourceGroups/{resource_group}"
        "/providers/Microsoft.Compute"
        "/virtualMachines/{vm_name}"
        "/runcommands/{command_name}"
        "?$expand=instanceView&api-version=2024-07-01"
    )

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new instance of Azure VM services proxy.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. An optional "customDataFile" entry names a file whose
            contents are loaded into the "customData" deployment parameter.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
        methods : Union[dict[str, Callable], list[Callable], None]
            New methods to register with the service.

        Raises
        ------
        ValueError
            If both "customDataFile" (in the config) and "customData" (in the
            deployment parameters) are specified.
        """
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(
                methods,
                [
                    # SupportsHostProvisioning
                    self.provision_host,
                    self.deprovision_host,
                    self.deallocate_host,
                    self.wait_host_deployment,
                    # SupportsHostOps
                    self.start_host,
                    self.stop_host,
                    self.restart_host,
                    self.wait_host_operation,
                    # SupportsOSOps
                    self.shutdown,
                    self.reboot,
                    self.wait_os_operation,
                    # SupportsRemoteExec
                    self.remote_exec,
                    self.get_remote_exec_results,
                ],
            ),
        )

        # As a convenience, allow reading customData out of a file, rather than
        # embedding it in a json config file.
        # Note: ARM templates expect this data to be base64 encoded, but that
        # can be done using the `base64()` string function inside the ARM template.
        self._custom_data_file = self.config.get("customDataFile", None)
        if self._custom_data_file:
            if self._deploy_params.get("customData", None):
                raise ValueError("Both customDataFile and customData are specified.")
            self._custom_data_file = self.config_loader_service.resolve_path(
                self._custom_data_file
            )
            with open(self._custom_data_file, encoding="utf-8") as custom_data_fh:
                self._deploy_params["customData"] = custom_data_fh.read()

    def _set_default_params(self, params: dict) -> dict:  # pylint: disable=no-self-use
        """
        Fill in a default "deploymentName" derived from "vmName" when missing.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs. It is updated IN PLACE.

        Returns
        -------
        params : dict
            The same dictionary object that was passed in.
        """
        # Try and provide a semi sane default for the deploymentName if not provided
        # since this is a common way to set the deploymentName and can save some
        # config work for the caller.
        if "vmName" in params and "deploymentName" not in params:
            params["deploymentName"] = f"{params['vmName']}-deployment"

            _LOG.info(
                "deploymentName missing from params. Defaulting to '%s'.",
                params["deploymentName"],
            )
        return params

    def wait_host_deployment(self, params: dict, *, is_setup: bool) -> tuple[Status, dict]:
        """
        Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED.
        Return TIMED_OUT when timing out.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        is_setup : bool
            If True, wait for VM being deployed; otherwise, wait for successful deprovisioning.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        # Thin wrapper: the actual waiting logic lives in the inherited helper.
        return self._wait_deployment(params, is_setup=is_setup)

    def wait_host_operation(self, params: dict) -> tuple[Status, dict]:
        """
        Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED.
        Return TIMED_OUT when timing out.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "vmName" key.
            Must have the "asyncResultsUrl" key to get the results.
            If the key is not present, return Status.PENDING.
            A default "deploymentName" is added in place if it is missing.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.

        Raises
        ------
        KeyError
            If "vmName" is missing from `params`.
        """
        _LOG.info("Wait for operation on VM %s", params["vmName"])
        # Try and provide a semi sane default for the deploymentName.
        # (Previously this called `params.setdefault(<default value>)`, which
        # used the default *value* as the key and never set "deploymentName".)
        params = self._set_default_params(params)
        return self._wait_while(self._check_operation_status, Status.RUNNING, params)

    def wait_remote_exec_operation(self, params: dict) -> tuple[Status, dict]:
        """
        Waits for a pending remote execution on an Azure VM to resolve to SUCCEEDED or
        FAILED. Return TIMED_OUT when timing out.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "commandName" and "vmName" keys (used for logging).
            Must have the "asyncResultsUrl" key to get the results.
            If the key is not present, return Status.PENDING.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
            Result is info on the operation runtime if SUCCEEDED, otherwise {}.
        """
        _LOG.info("Wait for run command %s on VM %s", params["commandName"], params["vmName"])
        return self._wait_while(self._check_remote_exec_status, Status.RUNNING, params)

    def wait_os_operation(self, params: dict) -> tuple[Status, dict]:
        """
        Wait for a pending OS operation; delegates to :py:meth:`wait_host_operation`.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.

        Returns
        -------
        result : (Status, dict)
            Whatever :py:meth:`wait_host_operation` returns for `params`.
        """
        return self.wait_host_operation(params)

    def provision_host(self, params: dict) -> tuple[Status, dict]:
        """
        Check if Azure VM is ready. Deploy a new VM, if necessary.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            HostEnv tunables are variable parameters that, together with the
            HostEnv configuration, are sufficient to provision a VM.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result. The result is the input `params` plus the
            parameters extracted from the response JSON, or {} if the status is FAILED.
            Status is one of {PENDING, SUCCEEDED, FAILED}
        """
        # Thin wrapper: provisioning is implemented by the inherited helper.
        return self._provision_resource(params)

    def deprovision_host(self, params: dict) -> tuple[Status, dict]:
        """
        Deprovisions the VM on Azure.

        NOTE: `_URL_DEPROVISION` is currently an alias of `_URL_DEALLOCATE`, so
        this issues a *deallocate* request rather than deleting the VM (see the
        TODO next to `_URL_DEPROVISION`).

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Together with the service config it must provide "subscription",
            "resourceGroup", "deploymentName", and "vmName".

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, as returned by the REST API POST helper.
            Status is one of {PENDING, SUCCEEDED, FAILED}
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "subscription",
                "resourceGroup",
                "deploymentName",
                "vmName",
            ],
        )
        _LOG.info("Deprovision VM: %s", config["vmName"])
        _LOG.info("Deprovision deployment: %s", config["deploymentName"])
        # TODO: Properly deprovision *all* resources specified in the ARM template.
        return self._azure_rest_api_post_helper(
            config,
            self._URL_DEPROVISION.format(
                subscription=config["subscription"],
                resource_group=config["resourceGroup"],
                vm_name=config["vmName"],
            ),
        )

    def deallocate_host(self, params: dict) -> tuple[Status, dict]:
        """
        Deallocates the VM on Azure by shutting it down then releasing the compute
        resources.

        Note: This can cause the VM to arrive on a new host node when its
        restarted, which may have different performance characteristics.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Together with the service config it must provide "subscription",
            "resourceGroup", and "vmName".

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, as returned by the REST API POST helper.
            Status is one of {PENDING, SUCCEEDED, FAILED}
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "subscription",
                "resourceGroup",
                "vmName",
            ],
        )
        _LOG.info("Deallocate VM: %s", config["vmName"])
        return self._azure_rest_api_post_helper(
            config,
            self._URL_DEALLOCATE.format(
                subscription=config["subscription"],
                resource_group=config["resourceGroup"],
                vm_name=config["vmName"],
            ),
        )

    def start_host(self, params: dict) -> tuple[Status, dict]:
        """
        Start the VM on Azure.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Together with the service config it must provide "subscription",
            "resourceGroup", and "vmName".

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, as returned by the REST API POST helper.
            Status is one of {PENDING, SUCCEEDED, FAILED}
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "subscription",
                "resourceGroup",
                "vmName",
            ],
        )
        _LOG.info("Start VM: %s :: %s", config["vmName"], params)
        return self._azure_rest_api_post_helper(
            config,
            self._URL_START.format(
                subscription=config["subscription"],
                resource_group=config["resourceGroup"],
                vm_name=config["vmName"],
            ),
        )

    def stop_host(self, params: dict, force: bool = False) -> tuple[Status, dict]:
        """
        Stops the VM on Azure by issuing a powerOff request.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Together with the service config it must provide "subscription",
            "resourceGroup", and "vmName".
        force : bool
            If True, force stop the Host/VM.
            NOTE: currently accepted for interface compatibility only; this
            implementation issues the same request regardless of its value.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, as returned by the REST API POST helper.
            Status is one of {PENDING, SUCCEEDED, FAILED}
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "subscription",
                "resourceGroup",
                "vmName",
            ],
        )
        _LOG.info("Stop VM: %s", config["vmName"])
        return self._azure_rest_api_post_helper(
            config,
            self._URL_STOP.format(
                subscription=config["subscription"],
                resource_group=config["resourceGroup"],
                vm_name=config["vmName"],
            ),
        )

    def shutdown(self, params: dict, force: bool = False) -> tuple[Status, dict]:
        """
        Shut down the host OS; delegates to :py:meth:`stop_host`.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        force : bool
            Passed through to :py:meth:`stop_host`.

        Returns
        -------
        result : (Status, dict)
            Whatever :py:meth:`stop_host` returns.
        """
        return self.stop_host(params, force)

    def restart_host(self, params: dict, force: bool = False) -> tuple[Status, dict]:
        """
        Reboot the VM on Azure by issuing a restart request.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Together with the service config it must provide "subscription",
            "resourceGroup", and "vmName".
        force : bool
            If True, force restart the Host/VM.
            NOTE: currently accepted for interface compatibility only; this
            implementation issues the same request regardless of its value.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result, as returned by the REST API POST helper.
            Status is one of {PENDING, SUCCEEDED, FAILED}
        """
        params = self._set_default_params(params)
        config = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "subscription",
                "resourceGroup",
                "vmName",
            ],
        )
        _LOG.info("Reboot VM: %s", config["vmName"])
        return self._azure_rest_api_post_helper(
            config,
            self._URL_REBOOT.format(
                subscription=config["subscription"],
                resource_group=config["resourceGroup"],
                vm_name=config["vmName"],
            ),
        )

    def reboot(self, params: dict, force: bool = False) -> tuple[Status, dict]:
        """
        Reboot the host OS; delegates to :py:meth:`restart_host`.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
        force : bool
            Passed through to :py:meth:`restart_host`.

        Returns
        -------
        result : (Status, dict)
            Whatever :py:meth:`restart_host` returns.
        """
        return self.restart_host(params, force)

    @staticmethod
    def _response_body_for_log(response: Any) -> str:
        """
        Render an HTTP response body for a debug log message.

        Pretty-prints JSON bodies and falls back to the raw text for bodies that
        are not valid JSON (e.g., an HTML error page), so that enabling DEBUG
        logging cannot change the outcome of a request.

        Parameters
        ----------
        response : requests.Response
            The HTTP response to render.

        Returns
        -------
        body : str
            Indented JSON, the raw response text, or "" for an empty body.
        """
        if not response.content:
            return ""
        try:
            return json.dumps(response.json(), indent=2)
        except ValueError:
            # `Response.json()` signals an undecodable body with a ValueError subclass.
            return str(response.text)

    def remote_exec(
        self,
        script: Iterable[str],
        config: dict,
        env_params: dict,
    ) -> tuple[Status, dict]:
        """
        Run a command on Azure VM.

        The script is submitted asynchronously as a named VM "run command"; use
        :py:meth:`get_remote_exec_results` with the returned dict to collect the
        outcome.

        Parameters
        ----------
        script : Iterable[str]
            A list of lines to execute as a script on a remote VM.
            Any iterable is accepted, including one-shot iterators.
        config : dict
            Flat dictionary of (key, value) pairs of the Environment parameters.
            They usually come from `const_args` and `tunable_params`
            properties of the Environment.
            Together with the service config it must provide "subscription",
            "resourceGroup", "vmName", "commandName", and "location".
        env_params : dict
            Parameters to pass as *shell* environment variables into the script.
            This is usually a subset of `config` with some possible conversions.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, FAILED}.
            On PENDING the result is the merged config plus an "asyncResultsUrl"
            key to poll for the results; on FAILED it is {}.
        """
        config = self._set_default_params(config)
        config = merge_parameters(
            dest=self.config.copy(),
            source=config,
            required_keys=[
                "subscription",
                "resourceGroup",
                "vmName",
                "commandName",
                "location",
            ],
        )

        # Materialize the script once: it is consumed twice below (log message
        # and request body), and a one-shot iterator would otherwise be empty
        # by the time the request is built.
        script_lines = list(script)

        if _LOG.isEnabledFor(logging.INFO):
            _LOG.info("Run a script on VM: %s\n %s", config["vmName"], "\n ".join(script_lines))

        json_req = {
            "location": config["location"],
            "properties": {
                "source": {"script": "; ".join(script_lines)},
                "protectedParameters": [
                    {"name": key, "value": val} for (key, val) in env_params.items()
                ],
                "timeoutInSeconds": int(self._poll_timeout),
                "asyncExecution": True,
            },
        }

        url = self._URL_REXEC_RUN.format(
            subscription=config["subscription"],
            resource_group=config["resourceGroup"],
            vm_name=config["vmName"],
            command_name=config["commandName"],
        )

        if _LOG.isEnabledFor(logging.DEBUG):
            # NOTE(review): this also logs the values sent as
            # "protectedParameters" - confirm that is acceptable at DEBUG level.
            _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2))

        response = requests.put(
            url,
            json=json_req,
            headers=self._get_headers(),
            timeout=self._request_timeout,
        )

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Response: %s\n%s", response, self._response_body_for_log(response))
        else:
            _LOG.info("Response: %s", response)

        if response.status_code in {200, 201}:
            results_url = self._URL_REXEC_RESULT.format(
                subscription=config["subscription"],
                resource_group=config["resourceGroup"],
                vm_name=config["vmName"],
                command_name=config["commandName"],
            )
            return (
                Status.PENDING,
                {**config, "asyncResultsUrl": results_url},
            )

        _LOG.error("Response: %s :: %s", response, response.text)
        return (Status.FAILED, {})

    def _check_remote_exec_status(self, params: dict) -> tuple[Status, dict]:
        """
        Checks the status of a pending remote execution on an Azure VM.

        Parameters
        ----------
        params: dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncResultsUrl" key to get the results.
            If the key is not present, return Status.PENDING.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, RUNNING, SUCCEEDED, FAILED}
            Result is the decoded response JSON if SUCCEEDED, otherwise {}.
        """
        url = params.get("asyncResultsUrl")
        if url is None:
            return Status.PENDING, {}

        session = self._get_session(params)
        try:
            response = session.get(url, timeout=self._request_timeout)
        except requests.exceptions.ReadTimeout:
            # A slow status query is not a failure of the command itself:
            # report RUNNING so that the caller keeps polling.
            _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url)
            return Status.RUNNING, {}
        except requests.exceptions.RequestException as ex:
            _LOG.exception("Error in request checking operation status", exc_info=ex)
            return (Status.FAILED, {})

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Response: %s\n%s", response, self._response_body_for_log(response))

        if response.status_code == 200:
            output = response.json()
            execution_state = (
                output.get("properties", {}).get("instanceView", {}).get("executionState")
            )
            if execution_state in {"Running", "Pending"}:
                return Status.RUNNING, {}
            if execution_state == "Succeeded":
                return Status.SUCCEEDED, output

        # Any other HTTP status or execution state is treated as a failure.
        _LOG.error("Response: %s :: %s", response, response.text)
        return Status.FAILED, {}

    def get_remote_exec_results(self, config: dict) -> tuple[Status, dict]:
        """
        Get the results of the asynchronously running command.

        Parameters
        ----------
        config : dict
            Flat dictionary of (key, value) pairs of tunable parameters.
            Must have the "asyncResultsUrl" key to get the results.
            If the key is not present, return Status.PENDING.

        Returns
        -------
        result : (Status, dict)
            A pair of Status and result.
            Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
            When the wait succeeds, the dict has "stdout" and "stderr" keys
            (lists of lines with the remote output and errors / warnings),
            plus "exitCode", "startTimestamp", and "endTimestamp".
            Otherwise the (status, result) pair of the wait is returned as is.

        Raises
        ------
        KeyError
            If the instance view of a succeeded command lacks "startTime" or
            "endTime".
        """
        _LOG.info("Check the results on VM: %s", config.get("vmName"))
        (status, result) = self.wait_remote_exec_operation(config)
        _LOG.debug("Result: %s :: %s", status, result)
        if not status.is_succeeded():
            # TODO: Extract the telemetry and status from stdout, if available
            return (status, result)

        output = result.get("properties", {}).get("instanceView", {})
        exit_code = output.get("exitCode")
        execution_state = output.get("executionState")
        # `or ""` also covers a key that is present but null in the response.
        outputs = (output.get("output") or "").strip().split("\n")
        errors = (output.get("error") or "").strip().split("\n")

        # The command itself may have completed with a non-zero exit code,
        # which is reported as a failure.
        if execution_state == "Succeeded" and exit_code == 0:
            status = Status.SUCCEEDED
        else:
            status = Status.FAILED

        return (
            status,
            {
                "stdout": outputs,
                "stderr": errors,
                "exitCode": exit_code,
                "startTimestamp": datetime.fromisoformat(output["startTime"]),
                "endTimestamp": datetime.fromisoformat(output["endTime"]),
            },
        )

678 )