Coverage for mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py: 100%

84 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tests for mlos_bench.services.remote.ssh.ssh_fileshare.""" 

6 

7import os 

8import tempfile 

9from collections.abc import Generator 

10from contextlib import contextmanager 

11from os.path import basename 

12from pathlib import Path 

13from tempfile import _TemporaryFileWrapper # pylint: disable=import-private-name 

14from typing import Any 

15 

16import pytest 

17 

18from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService 

19from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService 

20from mlos_bench.tests import are_dir_trees_equal, requires_docker 

21from mlos_bench.tests.services.remote.ssh import SshTestServerInfo 

22from mlos_bench.util import path_join 

23 

24 

@contextmanager
def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper]:
    """
    Provides a context manager for a temporary file that can be closed and still
    unlinked.

    Since Windows doesn't allow us to reopen the file while it's still open we
    need to handle deletion ourselves separately.

    Parameters
    ----------
    kwargs: dict
        Args to pass to NamedTemporaryFile constructor.

    Yields
    ------
    _TemporaryFileWrapper
        The open temporary file. The caller may close it (and reopen it by
        name) inside the context; the file is removed when the context exits.
    """
    fname = None
    try:
        # delete=False so that closing the handle does not remove the file;
        # removal is done explicitly in the finally clause below.
        with tempfile.NamedTemporaryFile(delete=False, **kwargs) as temp_file:
            fname = temp_file.name
            yield temp_file
    finally:
        if fname is not None:
            # Tolerate a file that is already gone (e.g. moved or removed by the
            # caller) so that cleanup never masks an exception from the body.
            Path(fname).unlink(missing_ok=True)

51 

52 

@requires_docker
def test_ssh_fileshare_single_file(
    ssh_test_server: SshTestServerInfo,
    ssh_fileshare_service: SshFileShareService,
) -> None:
    """Round-trip one file through SshFileShareService upload and download."""
    remote_file_path = "/tmp/test_ssh_fileshare_single_file"
    expected_lines = [f"{text}\n" for text in ("foo", "bar")]

    with ssh_fileshare_service:
        config = ssh_test_server.to_ssh_service_config()

        # Step 1: create the local source file and push it to the remote path.
        with closeable_temp_file(mode="w+t", encoding="utf-8") as src_file:
            src_file.writelines(expected_lines)
            src_file.flush()
            # Close before uploading so the file can be reopened by name.
            src_file.close()

            ssh_fileshare_service.upload(
                params=config,
                local_path=src_file.name,
                remote_path=remote_file_path,
            )

        # Step 2: pull the remote file into a fresh local file and check it.
        with closeable_temp_file(mode="w+t", encoding="utf-8") as dst_file:
            dst_file.close()
            ssh_fileshare_service.download(
                params=config,
                remote_path=remote_file_path,
                local_path=dst_file.name,
            )
            # The download replaces the inode behind that name, so the content
            # has to be read through a newly opened handle.
            with open(dst_file.name, encoding="utf-8") as dst_handle:
                actual_lines = dst_handle.readlines()
            assert actual_lines == expected_lines

93 

94 

@requires_docker
def test_ssh_fileshare_recursive(
    ssh_test_server: SshTestServerInfo,
    ssh_fileshare_service: SshFileShareService,
) -> None:
    """Round-trip a directory tree through recursive upload and download."""
    remote_file_path = "/tmp/test_ssh_fileshare_recursive_dir"
    raw_contents: dict[str, list[str]] = {
        "file-a.txt": ["a", "1"],
        "file-b.txt": ["b", "2"],
        "subdir/foo.txt": ["foo", "bar"],
    }
    # Terminate every entry with a newline so each one lands on its own line.
    files_lines: dict[str, list[str]] = {
        rel_path: [f"{text}\n" for text in texts] for rel_path, texts in raw_contents.items()
    }

    with ssh_fileshare_service:
        config = ssh_test_server.to_ssh_service_config()

        with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2:
            # Populate the source tree under tempdir1.
            for rel_path, content in files_lines.items():
                local_file = Path(tempdir1, rel_path)
                local_file.parent.mkdir(parents=True, exist_ok=True)
                with open(local_file, mode="w+t", encoding="utf-8") as out_handle:
                    out_handle.writelines(content)
                    out_handle.flush()
                assert os.path.getsize(local_file) > 0

            # Push the source tree to the remote server.
            ssh_fileshare_service.upload(
                params=config,
                local_path=tempdir1,
                remote_path=remote_file_path,
                recursive=True,
            )

            # Pull the remote tree back into the second local directory.
            ssh_fileshare_service.download(
                params=config,
                remote_path=remote_file_path,
                local_path=tempdir2,
                recursive=True,
            )

            # The remote directory name is appended to the download target,
            # so compare against that nested directory.
            downloaded_root = path_join(tempdir2, basename(remote_file_path))
            assert are_dir_trees_equal(tempdir1, downloaded_root)

152 

153 

@requires_docker
def test_ssh_fileshare_download_file_dne(
    ssh_test_server: SshTestServerInfo,
    ssh_fileshare_service: SshFileShareService,
) -> None:
    """Check that downloading a missing remote file raises and keeps the local file."""
    canary_str = "canary"

    with ssh_fileshare_service:
        config = ssh_test_server.to_ssh_service_config()

        with closeable_temp_file(mode="w+t", encoding="utf-8") as local_file:
            # Seed the local target with known content, then close it so it
            # can be reopened by name afterwards.
            local_file.write(canary_str)
            local_file.flush()
            local_file.close()

            with pytest.raises(FileNotFoundError):
                ssh_fileshare_service.download(
                    params=config,
                    remote_path="/tmp/file-dne.txt",
                    local_path=local_file.name,
                )

            # The failed download must leave the existing local content as is.
            with open(local_file.name, encoding="utf-8") as check_handle:
                found_lines = check_handle.readlines()
            assert found_lines == [canary_str]

179 

180 

@requires_docker
def test_ssh_fileshare_upload_file_dne(
    ssh_test_server: SshTestServerInfo,
    ssh_host_service: SshHostService,
    ssh_fileshare_service: SshFileShareService,
) -> None:
    """Check that uploading a missing local file raises and creates nothing remotely."""
    # The same path is used as the (missing) local source and the remote target.
    missing_path = "/tmp/upload-file-src-dne.txt"

    with ssh_host_service, ssh_fileshare_service:
        config = ssh_test_server.to_ssh_service_config()

        with pytest.raises(OSError):
            ssh_fileshare_service.upload(
                params=config,
                remote_path=missing_path,
                local_path=missing_path,
            )

        # Ask the remote shell whether the target exists; the script prints
        # the exit code of the "does not exist" test, so "0" means absent.
        _submit_status, pending = ssh_host_service.remote_exec(
            script=[f"[[ ! -e {missing_path} ]]; echo $?"],
            config=config,
            env_params={},
        )
        final_status, output = ssh_host_service.get_remote_exec_results(pending)
        assert final_status.is_succeeded()
        assert str(output["stdout"]).strip() == "0"