Coverage for mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py: 95%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""A collection of functions for interacting with SSH servers as file shares."""

6 

7import logging 

8from enum import Enum 

9 

10from asyncssh import SFTPError, SFTPFailure, SFTPNoSuchFile, SSHClientConnection, scp 

11 

12from mlos_bench.services.base_fileshare import FileShareService 

13from mlos_bench.services.remote.ssh.ssh_service import SshService 

14from mlos_bench.util import merge_parameters 

15 

16_LOG = logging.getLogger(__name__) 

17 

18 

class CopyMode(Enum):
    """Direction of a file copy between the local host and the SSH server."""

    # Remote path is the source, local path is the destination.
    DOWNLOAD = 1
    # Local path is the source, remote path is the destination.
    UPLOAD = 2

24 

25 

class SshFileShareService(FileShareService, SshService):
    """
    A collection of functions for interacting with SSH servers as file shares.

    Files are transferred with asyncssh's ``scp`` over a client connection
    obtained from ``self._get_client_connection``.
    """

    # pylint: disable=too-many-ancestors

    async def _start_file_copy(
        self,
        params: dict,
        mode: CopyMode,
        local_path: str,
        remote_path: str,
        recursive: bool = True,
    ) -> None:
        # pylint: disable=too-many-arguments,too-many-positional-arguments
        """
        Starts a file copy operation.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters (used for
            establishing the connection).
        mode : CopyMode
            Whether to download or upload the file.
        local_path : str
            Local path to the file/dir.
        remote_path : str
            Remote path to the file/dir.
        recursive : bool
            Whether to copy recursively. By default True.

        Raises
        ------
        OSError
            If the local OS returns an error.
        SFTPError
            If the remote OS returns an error. Note that a missing remote file
            is reported here as an SFTPError; it is `download()` that converts
            it to a FileNotFoundError.
        ValueError
            If `mode` is neither CopyMode.DOWNLOAD nor CopyMode.UPLOAD.
        """
        connection, _ = await self._get_client_connection(params)
        # For scp, the remote side is addressed as a (connection, path) tuple
        # and the local side as a plain path string.
        srcpaths: str | tuple[SSHClientConnection, str]
        dstpath: str | tuple[SSHClientConnection, str]
        if mode == CopyMode.DOWNLOAD:
            srcpaths = (connection, remote_path)
            dstpath = local_path
        elif mode == CopyMode.UPLOAD:
            srcpaths = local_path
            dstpath = (connection, remote_path)
        else:
            raise ValueError(f"Unknown copy mode: {mode}")
        return await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True)

    @staticmethod
    def _is_missing_remote_file(ex: BaseException) -> bool:
        """Check whether an exception indicates that the remote file does not exist."""
        if isinstance(ex, SFTPNoSuchFile):
            return True
        # Some servers report a missing file as a generic failure
        # (SFTP status code 4, SSH_FX_FAILURE) with only a textual reason,
        # so fall back to matching the reason case-insensitively.
        if not (isinstance(ex, SFTPFailure) and ex.code == 4):
            return False
        reason = ex.reason.lower()
        return any(msg in reason for msg in ("file not found", "no such file or directory"))

    def download(
        self,
        params: dict,
        remote_path: str,
        local_path: str,
        recursive: bool = True,
    ) -> None:
        """
        Downloads a file/dir from the SSH server and waits for the copy to finish.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters. Merged with a
            copy of the service config via `merge_parameters`, with
            "ssh_hostname" listed as a required key.
        remote_path : str
            Remote path to the file/dir.
        local_path : str
            Local path to the file/dir.
        recursive : bool
            Whether to copy recursively. By default True.

        Raises
        ------
        FileNotFoundError
            If the failure indicates that the remote file does not exist
            (chained from the original SFTPError).
        OSError
            Re-raised as-is after being logged.
        SFTPError
            Re-raised as-is after being logged (when not a missing-file error).
        """
        params = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "ssh_hostname",
            ],
        )
        super().download(params, remote_path, local_path, recursive)
        file_copy_future = self._run_coroutine(
            self._start_file_copy(
                params,
                CopyMode.DOWNLOAD,
                local_path,
                remote_path,
                recursive,
            )
        )
        try:
            # Block until the copy coroutine finishes (or fails).
            file_copy_future.result()
        except (OSError, SFTPError) as ex:
            # NOTE(review): params is logged in full here and below; confirm
            # that it cannot carry credentials or other sensitive settings.
            _LOG.error(
                "Failed to download %s to %s from %s: %s",
                remote_path,
                local_path,
                params,
                ex,
            )
            if self._is_missing_remote_file(ex):
                _LOG.warning("File %s does not exist on %s", remote_path, params)
                raise FileNotFoundError(f"File {remote_path} does not exist on {params}") from ex
            # Bare raise re-raises the active exception with its original traceback.
            raise

    def upload(
        self,
        params: dict,
        local_path: str,
        remote_path: str,
        recursive: bool = True,
    ) -> None:
        """
        Uploads a file/dir to the SSH server and waits for the copy to finish.

        Parameters
        ----------
        params : dict
            Flat dictionary of (key, value) pairs of parameters. Merged with a
            copy of the service config via `merge_parameters`, with
            "ssh_hostname" listed as a required key.
        local_path : str
            Local path to the file/dir.
        remote_path : str
            Remote path to the file/dir.
        recursive : bool
            Whether to copy recursively. By default True.

        Raises
        ------
        OSError
            Re-raised as-is after being logged.
        SFTPError
            Re-raised as-is after being logged.
        """
        params = merge_parameters(
            dest=self.config.copy(),
            source=params,
            required_keys=[
                "ssh_hostname",
            ],
        )
        super().upload(params, local_path, remote_path, recursive)
        file_copy_future = self._run_coroutine(
            self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive)
        )
        try:
            # Block until the copy coroutine finishes (or fails).
            file_copy_future.result()
        except (OSError, SFTPError) as ex:
            _LOG.error("Failed to upload %s to %s on %s: %s", local_path, remote_path, params, ex)
            # Bare raise re-raises the active exception with its original traceback.
            raise