Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py: 88%
78 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""A collection FileShare functions for interacting with Azure File Shares."""
7import logging
8import os
9from typing import Any, Callable, Dict, List, Optional, Set, Union
11from azure.core.credentials import TokenCredential
12from azure.core.exceptions import ResourceNotFoundError
13from azure.storage.fileshare import ShareClient
15from mlos_bench.services.base_fileshare import FileShareService
16from mlos_bench.services.base_service import Service
17from mlos_bench.services.types.authenticator_type import SupportsAuth
18from mlos_bench.util import check_required_params
# Module-level logger: the file share service below uses it for per-file
# upload/download DEBUG messages and for directory-loop warnings.
_LOG = logging.getLogger(__name__)
class AzureFileShareService(FileShareService):
    """Helper methods for interacting with Azure File Share."""

    # Template for the share URL; filled from the service config in _get_share_client().
    _SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}"

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
        global_config: Optional[Dict[str, Any]] = None,
        parent: Optional[Service] = None,
        methods: Union[Dict[str, Callable], List[Callable], None] = None,
    ):
        """
        Create a new file share Service for Azure environments with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the file share configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`. The `storageAccountName` and
            `storageFileShareName` keys are checked with `check_required_params`.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions. It must implement
            `SupportsAuth`; it supplies the credential for the share client.
        methods : Union[Dict[str, Callable], List[Callable], None]
            New methods to register with the service.
        """
        super().__init__(
            config,
            global_config,
            parent,
            # Register upload/download alongside any caller-supplied methods.
            self.merge_methods(methods, [self.upload, self.download]),
        )
        check_required_params(
            self.config,
            {
                "storageAccountName",
                "storageFileShareName",
            },
        )
        assert self._parent is not None and isinstance(
            self._parent, SupportsAuth
        ), "Authorization service not provided. Include service-auth.jsonc?"
        self._auth_service: SupportsAuth[TokenCredential] = self._parent
        # Created lazily on first use and then reused; see _get_share_client().
        self._share_client: Optional[ShareClient] = None

    def _get_share_client(self) -> ShareClient:
        """
        Get the Azure file share client object, creating it on the first call.

        The client points at the share named by `storageFileShareName` in the
        `storageAccountName` account. It authenticates with the credential
        obtained from the parent auth service.

        Returns
        -------
        ShareClient
            The cached share client.

        Raises
        ------
        AssertionError
            If the auth service does not return a `TokenCredential`.
        """
        if self._share_client is None:
            credential = self._auth_service.get_credential()
            assert isinstance(
                credential, TokenCredential
            ), f"Expected a TokenCredential, but got {type(credential)} instead."
            self._share_client = ShareClient.from_share_url(
                self._SHARE_URL.format(
                    account_name=self.config["storageAccountName"],
                    fs_name=self.config["storageFileShareName"],
                ),
                credential=credential,
                token_intent="backup",
            )
        return self._share_client

    def download(
        self,
        params: dict,
        remote_path: str,
        local_path: str,
        recursive: bool = True,
    ) -> None:
        """
        Download a file or a directory tree from the Azure file share.

        If `remote_path` is an existing directory, `local_path` is created as a
        directory and the remote entries are downloaded into it. Otherwise
        `remote_path` is treated as a single file and written to `local_path`.

        Parameters
        ----------
        params : dict
            Passed through unchanged to the base-class implementation and to
            recursive calls.
        remote_path : str
            Path in the remote file share to download, either a file or directory.
        local_path : str
            Local path to store the downloaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), download the entire directory tree.

        Raises
        ------
        FileNotFoundError
            If `remote_path` is neither an existing directory nor an existing file.
        """
        super().download(params, remote_path, local_path, recursive)
        dir_client = self._get_share_client().get_directory_client(remote_path)
        if dir_client.exists():
            os.makedirs(local_path, exist_ok=True)
            for content in dir_client.list_directories_and_files():
                name = content["name"]
                local_target = f"{local_path}/{name}"
                remote_target = f"{remote_path}/{name}"
                # Files are always fetched; subdirectories only when recursive.
                if recursive or not content["is_directory"]:
                    self.download(params, remote_target, local_target, recursive)
        else:  # Must be a file
            # Ensure parent folders exist. A bare file name has no directory part.
            # os.makedirs("") raises FileNotFoundError, which callers would
            # misread as a missing remote file, so skip it in that case.
            folder = os.path.dirname(local_path)
            if folder:
                os.makedirs(folder, exist_ok=True)
            file_client = self._get_share_client().get_file_client(remote_path)
            try:
                data = file_client.download_file()
                with open(local_path, "wb") as output_file:
                    _LOG.debug("Download file: %s -> %s", remote_path, local_path)
                    data.readinto(output_file)
            except ResourceNotFoundError as ex:
                # Translate into non-Azure exception:
                raise FileNotFoundError(f"Cannot download: {remote_path}") from ex

    def upload(
        self,
        params: dict,
        local_path: str,
        remote_path: str,
        recursive: bool = True,
    ) -> None:
        """
        Upload a local file or directory tree to the Azure file share.

        Parameters
        ----------
        params : dict
            Passed through unchanged to the base-class implementation.
        local_path : str
            Path to the local file or directory to upload.
        remote_path : str
            Path in the remote file share to store the uploaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), upload the entire directory tree.
        """
        super().upload(params, local_path, remote_path, recursive)
        self._upload(local_path, remote_path, recursive, set())

    def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]) -> None:
        """
        Upload contents from a local path to an Azure file share. This method is called
        from `.upload()` above. We need it to avoid exposing the `seen` parameter and to
        make `.upload()` match the base class' virtual method.

        Parameters
        ----------
        local_path : str
            Path to the local directory to upload contents from, either a file or directory.
        remote_path : str
            Path in the remote file share to store the uploaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), upload the entire directory tree.
        seen: Set[str]
            Resolved (symlink-free) paths of the local directories currently being
            uploaded, i.e., the ancestors of `local_path`. Used to break circular
            paths created by symbolic links. Each entry is removed again once its
            directory is done, so the set is left as it was passed in.
        """
        local_path = os.path.abspath(local_path)

        if os.path.isdir(local_path):
            # os.path.abspath() does not resolve symlinks, so a link pointing
            # back at one of its ancestors would produce an endless series of
            # distinct paths (dir/link/link/...). Compare resolved paths instead.
            real_path = os.path.realpath(local_path)
            if real_path in seen:
                _LOG.warning("Loop in directories, skipping '%s'", local_path)
                return
            # Track only the current ancestry. A directory reachable via two
            # sibling links is still uploaded to both remote locations; only
            # true cycles are cut.
            seen.add(real_path)
            try:
                self._remote_makedirs(remote_path)
                # Context manager closes the scandir iterator even on errors.
                with os.scandir(local_path) as entries:
                    for entry in entries:
                        name = entry.name
                        local_target = f"{local_path}/{name}"
                        remote_target = f"{remote_path}/{name}"
                        if recursive or not entry.is_dir():
                            self._upload(local_target, remote_target, recursive, seen)
            finally:
                seen.discard(real_path)
        else:
            # Ensure parent folders exist
            folder, _ = os.path.split(remote_path)
            self._remote_makedirs(folder)
            file_client = self._get_share_client().get_file_client(remote_path)
            with open(local_path, "rb") as file_data:
                _LOG.debug("Upload file: %s -> %s", local_path, remote_path)
                file_client.upload_file(file_data)

    def _remote_makedirs(self, remote_path: str) -> None:
        """
        Create remote directories for the entire path. Succeeds even if some or all
        directories along the path already exist.

        Parameters
        ----------
        remote_path : str
            Path in the remote file share to create. Backslashes are treated as
            separators, and empty components (e.g. from "" or "//") are skipped.
        """
        path = ""
        for folder in remote_path.replace("\\", "/").split("/"):
            if not folder:
                continue
            # Create each prefix in turn: a/, a/b/, a/b/c/ ...
            path += folder + "/"
            dir_client = self._get_share_client().get_directory_client(path)
            if not dir_client.exists():
                dir_client.create_directory()