Coverage for mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py: 96%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Fixtures for the SSH service tests. 

7 

8Note: these are not in the conftest.py file because they are also used by remote_ssh_env_test.py 

9""" 

10 

11import os 

12import sys 

13import tempfile 

14from subprocess import run 

15from typing import Generator 

16 

17import pytest 

18from pytest_docker.plugin import Services as DockerServices 

19 

20from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService 

21from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService 

22from mlos_bench.tests import resolve_host_name 

23from mlos_bench.tests.services.remote.ssh import ( 

24 ALT_TEST_SERVER_NAME, 

25 REBOOT_TEST_SERVER_NAME, 

26 SSH_TEST_SERVER_NAME, 

27 SshTestServerInfo, 

28 wait_docker_service_socket, 

29) 

30 

# pylint: disable=redefined-outer-name

# Special hostname that, when it resolves (e.g. thanks to --add-host /
# extra_hosts in docker-compose.yml), refers to the docker host machine.
HOST_DOCKER_NAME: str = "host.docker.internal"

34 

35 

@pytest.fixture(scope="session")
def ssh_test_server_hostname() -> str:
    """
    Pick the hostname used to reach the ports exposed by the test ssh server.

    Returns HOST_DOCKER_NAME when not on Windows and that name resolves,
    otherwise "localhost".
    """
    on_windows = sys.platform == "win32"
    if on_windows or not resolve_host_name(HOST_DOCKER_NAME):
        # Docker (Desktop) for Windows (WSL2) exposes container ports on the
        # host's `localhost`; in every other non-resolving case we assume we're
        # executing directly (e.g. inside conda) on the host itself.
        return "localhost"
    # On Linux inside a docker container, --add-host (extra_hosts in
    # docker-compose.yml) lets this alias refer to the host IP.
    return HOST_DOCKER_NAME

47 

48 

@pytest.fixture(scope="session")
def ssh_test_server(
    ssh_test_server_hostname: str,
    docker_compose_project_name: str,
    locked_docker_services: DockerServices,
) -> Generator[SshTestServerInfo, None, None]:
    """
    Fixture for getting the ssh test server services setup via docker-compose using
    pytest-docker.

    Yields an SshTestServerInfo describing the test server: compose project name,
    service name, hostname, username ("root"), and the local path of a copy of the
    server's private id_rsa key.

    Once the session is over, the docker containers are torn down, and the temporary
    directory holding the copy of the test server's private key is deleted.

    Raises
    ------
    RuntimeError
        If the private key could not be copied out of the container.
    """
    # Copy the key into a fresh temporary *directory* instead of an open
    # NamedTemporaryFile: this process holds no handle on the destination path
    # (Windows would otherwise refuse to let `docker cp` write to it), and the
    # whole directory, key included, is removed on context exit.
    with tempfile.TemporaryDirectory() as id_rsa_dir:
        id_rsa_path = os.path.join(id_rsa_dir, "id_rsa")
        ssh_test_server_info = SshTestServerInfo(
            compose_project_name=docker_compose_project_name,
            service_name=SSH_TEST_SERVER_NAME,
            hostname=ssh_test_server_hostname,
            username="root",
            id_rsa_path=id_rsa_path,
        )
        wait_docker_service_socket(
            locked_docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port()
        )
        id_rsa_src = f"/{ssh_test_server_info.username}/.ssh/id_rsa"
        # Build argv directly rather than str.split()-ing a command string, so a
        # temp path containing spaces stays a single argument.
        docker_cp_cmd = [
            "docker",
            "compose",
            "-p",
            docker_compose_project_name,
            "cp",
            f"{SSH_TEST_SERVER_NAME}:{id_rsa_src}",
            id_rsa_path,
        ]
        # check=False so the descriptive error below (with docker's stderr) is
        # reachable; with check=True, run() raises CalledProcessError first and
        # the captured stderr is never reported.
        cmd = run(
            docker_cp_cmd,
            check=False,
            # presumably so `docker compose` resolves the compose file that lives
            # next to this module -- TODO confirm
            cwd=os.path.dirname(__file__),
            capture_output=True,
            text=True,
        )
        if cmd.returncode != 0:
            raise RuntimeError(
                f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container "
                f"[return={cmd.returncode}]: {cmd.stderr}"
            )
        # Restrict the key to owner-only access (ssh clients typically reject
        # private keys readable by other users).
        os.chmod(id_rsa_path, 0o600)
        yield ssh_test_server_info
        # The temporary directory (and the key copy) is deleted on context exit.

96 

97 

@pytest.fixture(scope="session")
def alt_test_server(
    ssh_test_server: SshTestServerInfo,
    locked_docker_services: DockerServices,
) -> SshTestServerInfo:
    """
    Fixture for getting the second ssh test server info from the docker-compose.yml.

    Reuses the primary ssh_test_server connection details with the
    ALT_TEST_SERVER_NAME service, then waits for its socket to accept
    connections. See additional notes in the ssh_test_server fixture.
    """
    # The alt server runs the same image as the primary ssh server, so the
    # username and id_rsa key are shared; only its allocated host port differs.
    primary = ssh_test_server
    info = SshTestServerInfo(
        compose_project_name=primary.compose_project_name,
        service_name=ALT_TEST_SERVER_NAME,
        hostname=primary.hostname,
        username=primary.username,
        id_rsa_path=primary.id_rsa_path,
    )
    host = info.hostname
    port = info.get_port()
    wait_docker_service_socket(locked_docker_services, host, port)
    return info

122 

123 

@pytest.fixture(scope="session")
def reboot_test_server(
    ssh_test_server: SshTestServerInfo,
    locked_docker_services: DockerServices,
) -> SshTestServerInfo:
    """
    Fixture for getting the third ssh test server info from the docker-compose.yml.

    Reuses the primary ssh_test_server connection details with the
    REBOOT_TEST_SERVER_NAME service, then waits for its socket to accept
    connections. See additional notes in the ssh_test_server fixture.
    """
    # The reboot server runs the same image as the primary ssh server, so the
    # username and id_rsa key are shared; only its allocated host port differs.
    primary = ssh_test_server
    info = SshTestServerInfo(
        compose_project_name=primary.compose_project_name,
        service_name=REBOOT_TEST_SERVER_NAME,
        hostname=primary.hostname,
        username=primary.username,
        id_rsa_path=primary.id_rsa_path,
    )
    host = info.hostname
    port = info.get_port()
    wait_docker_service_socket(locked_docker_services, host, port)
    return info

150 

151 

@pytest.fixture
def ssh_host_service(ssh_test_server: SshTestServerInfo) -> SshHostService:
    """
    Generic SshHostService fixture.

    Only the ssh username and private key path of the test server are
    configured; no hostname or port is set here.
    """
    service_config = {
        "ssh_username": ssh_test_server.username,
        "ssh_priv_key_path": ssh_test_server.id_rsa_path,
    }
    return SshHostService(config=service_config)

161 

162 

@pytest.fixture
def ssh_fileshare_service() -> SshFileShareService:
    """
    Generic SshFileShareService fixture.

    The config is deliberately empty so that tests exercise the
    per-connection overrides instead of service-level defaults.
    """
    empty_config: dict = {}
    return SshFileShareService(config=empty_config)