Coverage for mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py: 89%
79 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""A collection of FileShare functions for interacting with Azure File Shares."""
7import logging
8import os
9from collections.abc import Callable
10from typing import Any
12from azure.core.credentials import TokenCredential
13from azure.core.exceptions import ResourceNotFoundError
14from azure.storage.fileshare import ShareClient
16from mlos_bench.services.base_fileshare import FileShareService
17from mlos_bench.services.base_service import Service
18from mlos_bench.services.types.authenticator_type import SupportsAuth
19from mlos_bench.util import check_required_params
21_LOG = logging.getLogger(__name__)
class AzureFileShareService(FileShareService):
    """Helper methods for interacting with Azure File Share."""

    # URL template of the share; filled in from the `storageAccountName` and
    # `storageFileShareName` config entries in `_get_share_client()`.
    _SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}"

    def __init__(
        self,
        config: dict[str, Any] | None = None,
        global_config: dict[str, Any] | None = None,
        parent: Service | None = None,
        methods: dict[str, Callable] | list[Callable] | None = None,
    ):
        """
        Create a new file share Service for Azure environments with a given config.

        Parameters
        ----------
        config : dict
            Free-format dictionary that contains the file share configuration.
            It will be passed as a constructor parameter of the class
            specified by `class_name`.
            Must contain the `storageAccountName` and `storageFileShareName` keys.
        global_config : dict
            Free-format dictionary of global parameters.
        parent : Service
            Parent service that can provide mixin functions.
            It must implement `SupportsAuth`, as it is used to obtain credentials.
        methods : Union[dict[str, Callable], list[Callable], None]
            New methods to register with the service.
        """
        super().__init__(
            config,
            global_config,
            parent,
            self.merge_methods(methods, [self.upload, self.download]),
        )
        check_required_params(
            self.config,
            {
                "storageAccountName",
                "storageFileShareName",
            },
        )
        assert self._parent is not None and isinstance(
            self._parent, SupportsAuth
        ), "Authorization service not provided. Include service-auth.jsonc?"
        self._auth_service: SupportsAuth[TokenCredential] = self._parent
        # Created on first use by `_get_share_client()`.
        self._share_client: ShareClient | None = None

    def _get_share_client(self) -> ShareClient:
        """
        Get the Azure file share client object.

        The client is created on the first call, using a credential obtained from
        the authorization service, and is reused on subsequent calls.

        Returns
        -------
        share_client : ShareClient
            Client for the file share named in the service config.
        """
        if self._share_client is None:
            credential = self._auth_service.get_credential()
            assert isinstance(
                credential, TokenCredential
            ), f"Expected a TokenCredential, but got {type(credential)} instead."
            self._share_client = ShareClient.from_share_url(
                self._SHARE_URL.format(
                    account_name=self.config["storageAccountName"],
                    fs_name=self.config["storageFileShareName"],
                ),
                credential=credential,
                # NOTE(review): presumably required by the SDK when authenticating
                # with a token credential -- confirm against the ShareClient docs.
                token_intent="backup",
            )
        return self._share_client

    def download(
        self,
        params: dict,
        remote_path: str,
        local_path: str,
        recursive: bool = True,
    ) -> None:
        """
        Download a file or a directory from the Azure file share.

        If `remote_path` exists as a directory on the share, its entries are
        downloaded into `local_path` (which is created if missing).
        Otherwise `remote_path` is treated as a single file.

        Parameters
        ----------
        params : dict
            Passed to the base class implementation and to the recursive calls;
            not otherwise used by this method.
        remote_path : str
            Path in the remote file share to download from, either a file or directory.
        local_path : str
            Local path to store the downloaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), download the entire directory tree.

        Raises
        ------
        FileNotFoundError
            If `remote_path` is not a directory and the file download reports
            that the resource does not exist.
        """
        super().download(params, remote_path, local_path, recursive)
        dir_client = self._get_share_client().get_directory_client(remote_path)
        if dir_client.exists():
            os.makedirs(local_path, exist_ok=True)
            for content in dir_client.list_directories_and_files():
                name = content["name"]
                local_target = f"{local_path}/{name}"
                remote_target = f"{remote_path}/{name}"
                if recursive or not content["is_directory"]:
                    self.download(params, remote_target, local_target, recursive)
        else:  # Must be a file
            # Ensure parent folders exist.
            # A bare file name has no parent component, and `os.makedirs("")`
            # raises FileNotFoundError, so only create the folder if there is one.
            folder, _ = os.path.split(local_path)
            if folder:
                os.makedirs(folder, exist_ok=True)
            file_client = self._get_share_client().get_file_client(remote_path)
            try:
                data = file_client.download_file()
                with open(local_path, "wb") as output_file:
                    _LOG.debug("Download file: %s -> %s", remote_path, local_path)
                    data.readinto(output_file)
            except ResourceNotFoundError as ex:
                # Translate into non-Azure exception:
                raise FileNotFoundError(f"Cannot download: {remote_path}") from ex

    def upload(
        self,
        params: dict,
        local_path: str,
        remote_path: str,
        recursive: bool = True,
    ) -> None:
        """
        Upload a local file or directory to the Azure file share.

        Parameters
        ----------
        params : dict
            Passed to the base class implementation; not otherwise used here.
        local_path : str
            Path to the local content to upload, either a file or directory.
        remote_path : str
            Path in the remote file share to store the uploaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), upload the entire directory tree.
        """
        super().upload(params, local_path, remote_path, recursive)
        self._upload(local_path, remote_path, recursive, set())

    def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: set[str]) -> None:
        """
        Upload contents from a local path to an Azure file share. This method is called
        from `.upload()` above. We need it to avoid exposing the `seen` parameter and to
        make `.upload()` match the base class' virtual method.

        Parameters
        ----------
        local_path : str
            Path to the local directory to upload contents from, either a file or directory.
        remote_path : str
            Path in the remote file share to store the uploaded content to.
        recursive : bool
            If False, ignore the subdirectories;
            if True (the default), upload the entire directory tree.
        seen: set[str]
            Helper set of the symlink-resolved directories that are currently being
            traversed (i.e., the ancestors of `local_path`); used to break circular paths.
        """
        local_path = os.path.abspath(local_path)

        if os.path.isdir(local_path):
            # `abspath` does not resolve symlinks, so paths that go through a symlink
            # cycle (a/loop/loop/...) are all distinct strings. Compare the resolved
            # locations instead to actually detect the loop.
            real_dir = os.path.realpath(local_path)
            if real_dir in seen:
                _LOG.warning("Loop in directories, skipping '%s'", local_path)
                return
            seen.add(real_dir)
            try:
                self._remote_makedirs(remote_path)
                for entry in os.scandir(local_path):
                    name = entry.name
                    local_target = f"{local_path}/{name}"
                    remote_target = f"{remote_path}/{name}"
                    if recursive or not entry.is_dir():
                        self._upload(local_target, remote_target, recursive, seen)
            finally:
                # Track only the current chain of ancestors, so that a directory
                # reachable via two different (non-circular) paths is uploaded in both.
                seen.discard(real_dir)
        else:
            # Ensure parent folders exist
            folder, _ = os.path.split(remote_path)
            self._remote_makedirs(folder)
            file_client = self._get_share_client().get_file_client(remote_path)
            with open(local_path, "rb") as file_data:
                _LOG.debug("Upload file: %s -> %s", local_path, remote_path)
                file_client.upload_file(file_data)

    def _remote_makedirs(self, remote_path: str) -> None:
        """
        Create remote directories for the entire path. Succeeds even if some or all
        directories along the path already exist.

        Parameters
        ----------
        remote_path : str
            Path in the remote file share to create.
            Both `/` and `\\` are accepted as separators; empty components are skipped.
        """
        path = ""
        for folder in remote_path.replace("\\", "/").split("/"):
            if not folder:
                continue
            # Grow the path one component at a time, creating each missing level.
            path += folder + "/"
            dir_client = self._get_share_client().get_directory_client(path)
            if not dir_client.exists():
                dir_client.create_directory()