Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py: 88%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

"""A collection of FileShare functions for interacting with Azure File Shares."""

6 

7import logging 

8import os 

9from typing import Any, Callable, Dict, List, Optional, Set, Union 

10 

11from azure.core.credentials import TokenCredential 

12from azure.core.exceptions import ResourceNotFoundError 

13from azure.storage.fileshare import ShareClient 

14 

15from mlos_bench.services.base_fileshare import FileShareService 

16from mlos_bench.services.base_service import Service 

17from mlos_bench.services.types.authenticator_type import SupportsAuth 

18from mlos_bench.util import check_required_params 

19 

# Module-level logger, named after this module.
_LOG = logging.getLogger(name=__name__)

21 

22 

class AzureFileShareService(FileShareService):
    """Helper methods for interacting with Azure File Share."""

    # URL template for the file share; filled in from the `storageAccountName`
    # and `storageFileShareName` config values in `_get_share_client()`.
    _SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}"

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new file share Service for Azure environments with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the file share configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`. Must contain `storageAccountName` and
            `storageFileShareName` (checked by `check_required_params`).
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions. Must implement
            `SupportsAuth`; it supplies the credential for the share client.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.

        Raises
        ------
        AssertionError
            If `parent` is missing or does not implement `SupportsAuth`.
        """
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(methods, [self.upload, self.download]),
        )
        check_required_params(
            self.config,
            {
                "storageAccountName",
                "storageFileShareName",
            },
        )
        assert self._parent is not None and isinstance(
            self._parent, SupportsAuth
        ), "Authorization service not provided. Include service-auth.jsonc?"
        self._auth_service: SupportsAuth[TokenCredential] = self._parent
        # Created lazily (and then cached) by `_get_share_client()`.
        self._share_client: Optional[ShareClient] = None

    def _get_share_client(self) -> ShareClient:
        """
        Get the Azure file share client object.

        The client is created on the first call, using the credential from the
        parent auth service, and reused on subsequent calls.

        Raises
        ------
        AssertionError
            If the auth service returns something other than a `TokenCredential`.
        """
        if self._share_client is None:
            credential = self._auth_service.get_credential()
            assert isinstance(
                credential, TokenCredential
            ), f"Expected a TokenCredential, but got {type(credential)} instead."
            self._share_client = ShareClient.from_share_url(
                self._SHARE_URL.format(
                    account_name=self.config["storageAccountName"],
                    fs_name=self.config["storageFileShareName"],
                ),
                credential=credential,
                # NOTE(review): the Azure Files SDK requires `token_intent` when
                # authenticating with a TokenCredential -- confirm against SDK docs.
                token_intent="backup",
            )
        return self._share_client

    def download(
        self,
        params: dict,
        remote_path: str,
        local_path: str,
        recursive: bool = True,
    ) -> None:
        """
        Download a file or directory from the Azure file share to a local path.

        Parameters
        ----------
        params : dict
            Passed through to the base class `download()` and to recursive calls.
        remote_path : str
            Path in the remote file share to download; a directory or a file.
        local_path : str
            Local destination path. Missing parent directories are created.
        recursive : bool
            If False, skip the subdirectories of a remote directory;
            if True (the default), download the entire directory tree.

        Raises
        ------
        FileNotFoundError
            If `remote_path` is not an existing directory and the file download
            reports `ResourceNotFoundError`.
        """
        super().download(params, remote_path, local_path, recursive)
        dir_client = self._get_share_client().get_directory_client(remote_path)
        if dir_client.exists():
            os.makedirs(local_path, exist_ok=True)
            for content in dir_client.list_directories_and_files():
                name = content["name"]
                local_target = f"{local_path}/{name}"
                remote_target = f"{remote_path}/{name}"
                # Files are always copied; subdirectories only when recursive.
                if recursive or not content["is_directory"]:
                    self.download(params, remote_target, local_target, recursive)
        else:  # Must be a file
            # Ensure parent folders exist. `folder` is empty for a bare file
            # name (e.g. "out.txt"), and os.makedirs("") raises, so skip that case.
            folder, _ = os.path.split(local_path)
            if folder:
                os.makedirs(folder, exist_ok=True)
            file_client = self._get_share_client().get_file_client(remote_path)
            try:
                data = file_client.download_file()
                with open(local_path, "wb") as output_file:
                    _LOG.debug("Download file: %s -> %s", remote_path, local_path)
                    data.readinto(output_file)  # type: ignore[no-untyped-call]
            except ResourceNotFoundError as ex:
                # Translate into non-Azure exception:
                raise FileNotFoundError(f"Cannot download: {remote_path}") from ex

    def upload(
        self,
        params: dict,
        local_path: str,
        remote_path: str,
        recursive: bool = True,
    ) -> None:
        """
        Upload a local file or directory to the Azure file share.

        Parameters
        ----------
        params : dict
            Passed through to the base class `upload()`.
        local_path : str
            Local file or directory to upload.
        remote_path : str
            Path in the remote file share to store the uploaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), upload the entire directory tree.
        """
        super().upload(params, local_path, remote_path, recursive)
        self._upload(local_path, remote_path, recursive, set())

    def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]) -> None:
        """
        Upload contents from a local path to an Azure file share. This method is called
        from `.upload()` above. We need it to avoid exposing the `seen` parameter and to
        make `.upload()` match the base class' virtual method.

        Parameters
        ----------
        local_path : str
            Path to the local directory to upload contents from, either a file or directory.
        remote_path : str
            Path in the remote file share to store the uploaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), upload the entire directory tree.
        seen: Set[str]
            Real paths (symlinks resolved) of the local directories currently
            being uploaded, i.e. the ancestors of `local_path`. Used to break
            circular paths created by symbolic links. The set is restored to its
            original contents when this call returns.
        """
        local_path = os.path.abspath(local_path)

        if os.path.isdir(local_path):
            # Resolve symlinks for loop detection: with plain abspath(), a link
            # pointing back to an ancestor yields a new, longer path at every
            # level, so the loop would never be detected.
            real_path = os.path.realpath(local_path)
            if real_path in seen:
                _LOG.warning("Loop in directories, skipping '%s'", local_path)
                return
            seen.add(real_path)
            try:
                self._remote_makedirs(remote_path)
                with os.scandir(local_path) as entries:
                    for entry in entries:
                        name = entry.name
                        local_target = f"{local_path}/{name}"
                        remote_target = f"{remote_path}/{name}"
                        if recursive or not entry.is_dir():
                            self._upload(local_target, remote_target, recursive, seen)
            finally:
                # Only ancestors indicate a loop; sibling links to the same
                # directory are not cycles and must still be uploaded.
                seen.discard(real_path)
        else:
            # Ensure parent folders exist
            folder, _ = os.path.split(remote_path)
            self._remote_makedirs(folder)
            file_client = self._get_share_client().get_file_client(remote_path)
            with open(local_path, "rb") as file_data:
                _LOG.debug("Upload file: %s -> %s", local_path, remote_path)
                file_client.upload_file(file_data)

    def _remote_makedirs(self, remote_path: str) -> None:
        """
        Create remote directories for the entire path. Succeeds even if some or all
        directories along the path already exist.

        Parameters
        ----------
        remote_path : str
            Path in the remote file share to create. Backslashes are treated as
            separators; empty components (leading, trailing, or doubled
            separators) are skipped, so an empty path is a no-op.
        """
        path = ""
        for folder in remote_path.replace("\\", "/").split("/"):
            if not folder:
                continue
            # Walk the path one component at a time, creating each missing level.
            path += folder + "/"
            dir_client = self._get_share_client().get_directory_client(path)
            if not dir_client.exists():
                dir_client.create_directory()