Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py: 89%

79 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""A collection FileShare functions for interacting with Azure File Shares.""" 

6 

7import logging 

8import os 

9from collections.abc import Callable 

10from typing import Any 

11 

12from azure.core.credentials import TokenCredential 

13from azure.core.exceptions import ResourceNotFoundError 

14from azure.storage.fileshare import ShareClient 

15 

16from mlos_bench.services.base_fileshare import FileShareService 

17from mlos_bench.services.base_service import Service 

18from mlos_bench.services.types.authenticator_type import SupportsAuth 

19from mlos_bench.util import check_required_params 

20 

21_LOG = logging.getLogger(__name__) 

22 

23 

24class AzureFileShareService(FileShareService): 

25 """Helper methods for interacting with Azure File Share.""" 

26 

27 _SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}" 

28 

29 def __init__( 

30 self, 

31 config: dict[str, Any] | None = None, 

32 global_config: dict[str, Any] | None = None, 

33 parent: Service | None = None, 

34 methods: dict[str, Callable] | list[Callable] | None = None, 

35 ): 

36 """ 

37 Create a new file share Service for Azure environments with a given config. 

38 

39 Parameters 

40 ---------- 

41 config : dict 

42 Free-format dictionary that contains the file share configuration. 

43 It will be passed as a constructor parameter of the class 

44 specified by `class_name`. 

45 global_config : dict 

46 Free-format dictionary of global parameters. 

47 parent : Service 

48 Parent service that can provide mixin functions. 

49 methods : Union[dict[str, Callable], list[Callable], None] 

50 New methods to register with the service. 

51 """ 

52 super().__init__( 

53 config, 

54 global_config, 

55 parent, 

56 self.merge_methods(methods, [self.upload, self.download]), 

57 ) 

58 check_required_params( 

59 self.config, 

60 { 

61 "storageAccountName", 

62 "storageFileShareName", 

63 }, 

64 ) 

65 assert self._parent is not None and isinstance( 

66 self._parent, SupportsAuth 

67 ), "Authorization service not provided. Include service-auth.jsonc?" 

68 self._auth_service: SupportsAuth[TokenCredential] = self._parent 

69 self._share_client: ShareClient | None = None 

70 

71 def _get_share_client(self) -> ShareClient: 

72 """Get the Azure file share client object.""" 

73 if self._share_client is None: 

74 credential = self._auth_service.get_credential() 

75 assert isinstance( 

76 credential, TokenCredential 

77 ), f"Expected a TokenCredential, but got {type(credential)} instead." 

78 self._share_client = ShareClient.from_share_url( 

79 self._SHARE_URL.format( 

80 account_name=self.config["storageAccountName"], 

81 fs_name=self.config["storageFileShareName"], 

82 ), 

83 credential=credential, 

84 token_intent="backup", 

85 ) 

86 return self._share_client 

87 

88 def download( 

89 self, 

90 params: dict, 

91 remote_path: str, 

92 local_path: str, 

93 recursive: bool = True, 

94 ) -> None: 

95 super().download(params, remote_path, local_path, recursive) 

96 dir_client = self._get_share_client().get_directory_client(remote_path) 

97 if dir_client.exists(): 

98 os.makedirs(local_path, exist_ok=True) 

99 for content in dir_client.list_directories_and_files(): 

100 name = content["name"] 

101 local_target = f"{local_path}/{name}" 

102 remote_target = f"{remote_path}/{name}" 

103 if recursive or not content["is_directory"]: 

104 self.download(params, remote_target, local_target, recursive) 

105 else: # Must be a file 

106 # Ensure parent folders exist 

107 folder, _ = os.path.split(local_path) 

108 os.makedirs(folder, exist_ok=True) 

109 file_client = self._get_share_client().get_file_client(remote_path) 

110 try: 

111 data = file_client.download_file() 

112 with open(local_path, "wb") as output_file: 

113 _LOG.debug("Download file: %s -> %s", remote_path, local_path) 

114 data.readinto(output_file) 

115 except ResourceNotFoundError as ex: 

116 # Translate into non-Azure exception: 

117 raise FileNotFoundError(f"Cannot download: {remote_path}") from ex 

118 

119 def upload( 

120 self, 

121 params: dict, 

122 local_path: str, 

123 remote_path: str, 

124 recursive: bool = True, 

125 ) -> None: 

126 super().upload(params, local_path, remote_path, recursive) 

127 self._upload(local_path, remote_path, recursive, set()) 

128 

129 def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: set[str]) -> None: 

130 """ 

131 Upload contents from a local path to an Azure file share. This method is called 

132 from `.upload()` above. We need it to avoid exposing the `seen` parameter and to 

133 make `.upload()` match the base class' virtual method. 

134 

135 Parameters 

136 ---------- 

137 local_path : str 

138 Path to the local directory to upload contents from, either a file or directory. 

139 remote_path : str 

140 Path in the remote file share to store the uploaded content to. 

141 recursive : bool 

142 If False, ignore the subdirectories; 

143 if True (the default), upload the entire directory tree. 

144 seen: set[str] 

145 Helper set for keeping track of visited directories to break circular paths. 

146 """ 

147 local_path = os.path.abspath(local_path) 

148 if local_path in seen: 

149 _LOG.warning("Loop in directories, skipping '%s'", local_path) 

150 return 

151 seen.add(local_path) 

152 

153 if os.path.isdir(local_path): 

154 self._remote_makedirs(remote_path) 

155 for entry in os.scandir(local_path): 

156 name = entry.name 

157 local_target = f"{local_path}/{name}" 

158 remote_target = f"{remote_path}/{name}" 

159 if recursive or not entry.is_dir(): 

160 self._upload(local_target, remote_target, recursive, seen) 

161 else: 

162 # Ensure parent folders exist 

163 folder, _ = os.path.split(remote_path) 

164 self._remote_makedirs(folder) 

165 file_client = self._get_share_client().get_file_client(remote_path) 

166 with open(local_path, "rb") as file_data: 

167 _LOG.debug("Upload file: %s -> %s", local_path, remote_path) 

168 file_client.upload_file(file_data) 

169 

170 def _remote_makedirs(self, remote_path: str) -> None: 

171 """ 

172 Create remote directories for the entire path. Succeeds even some or all 

173 directories along the path already exist. 

174 

175 Parameters 

176 ---------- 

177 remote_path : str 

178 Path in the remote file share to create. 

179 """ 

180 path = "" 

181 for folder in remote_path.replace("\\", "/").split("/"): 

182 if not folder: 

183 continue 

184 path += folder + "/" 

185 dir_client = self._get_share_client().get_directory_client(path) 

186 if not dir_client.exists(): 

187 dir_client.create_directory()