Coverage for mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py: 100%

18 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""Unit tests for RemoteEnv benchmark environment via local SSH test services."""

6 

7import os 

8from importlib.resources import files 

9 

10import numpy as np 

11import pytest 

12 

13from mlos_bench.services.config_persistence import ConfigPersistenceService 

14from mlos_bench.tests import requires_docker 

15from mlos_bench.tests.environments import check_env_success 

16from mlos_bench.tests.services.remote.ssh import SshTestServerInfo 

17from mlos_bench.tunables.tunable import TunableValue 

18from mlos_bench.tunables.tunable_groups import TunableGroups 

19 

20 

@requires_docker
def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None:
    """
    Load the SSH-backed test environment config and verify what it reports.

    The environment described by ``environments/remote/test_ssh_env.jsonc`` is
    pointed at the ``ssh_test_server`` fixture through the global config. It is
    then checked with ``check_env_success`` against a fixed set of expected
    results and an empty telemetry list. Finally, the test confirms that no
    ``output-downloaded.csv`` was left behind in the current working directory.
    """
    # SSH connection details for the test server, handed to the environment
    # as global config values.
    global_config: dict[str, TunableValue] = dict(
        ssh_hostname=ssh_test_server.hostname,
        ssh_port=ssh_test_server.get_port(),
        ssh_username=ssh_test_server.username,
        ssh_priv_key_path=ssh_test_server.id_rsa_path,
    )

    # Resolve config files relative to the packaged test config directory.
    search_dirs = [str(files("mlos_bench.tests.config"))]
    service = ConfigPersistenceService(config={"config_path": search_dirs})
    env_config_file = service.resolve_path("environments/remote/test_ssh_env.jsonc")
    env = service.load_environment(
        env_config_file,
        TunableGroups(),
        global_config=global_config,
        service=service,
    )

    expected_results = {
        "hostname": ssh_test_server.service_name,
        "username": ssh_test_server.username,
        "score": 0.9,
        # Empty strings are returned as "not a number".
        "ssh_priv_key_path": np.nan,
        "test_param": "unset",
        "FOO": "unset",
        "ssh_username": "unset",
    }
    check_env_success(
        env,
        env.tunable_params,
        expected_results=expected_results,
        expected_telemetry=[],
    )

    # The downloaded output file must not leak into the working directory.
    leftover_csv = os.path.join(os.getcwd(), "output-downloaded.csv")
    assert not os.path.exists(
        leftover_csv
    ), "output-downloaded.csv should have been cleaned up by temp_dir context"

59 

60 

if __name__ == "__main__":
    # Run only this file's tests when executed directly.
    # NOTE(review): "-n1" is presumably the pytest-xdist worker count — confirm.
    _pytest_args = ["-n1", __file__]
    pytest.main(_pytest_args)