Coverage for mlos_bench/mlos_bench/environments/local/local_fileshare_env.py: 88%
56 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Scheduler-side Environment to run scripts locally and upload/download data to the
6shared storage.
7"""
9import logging
10from collections.abc import Generator, Iterable, Mapping
11from datetime import datetime
12from string import Template
13from typing import Any
15from mlos_bench.environments.local.local_env import LocalEnv
16from mlos_bench.environments.status import Status
17from mlos_bench.services.base_service import Service
18from mlos_bench.services.types.fileshare_type import SupportsFileShareOps
19from mlos_bench.services.types.local_exec_type import SupportsLocalExec
20from mlos_bench.tunables.tunable import TunableValue
21from mlos_bench.tunables.tunable_groups import TunableGroups
23_LOG = logging.getLogger(__name__)
class LocalFileShareEnv(LocalEnv):
    """Scheduler-side Environment that runs scripts locally and uploads/downloads data
    to the shared file storage.

    On top of what `LocalEnv` does, this environment uploads files to the shared
    storage at the end of `setup()` and downloads files from it before `run()`
    and `status()`. The files to transfer come from the optional "upload" and
    "download" sections of the config; each is a list of
    `{"from": "...", "to": "..."}` entries whose paths may contain `$var`
    placeholders.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: dict | None = None,
        tunables: TunableGroups | None = None,
        service: Service | None = None,
    ):
        """
        Create a new application environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the "tunable_params"
            and the "const_args" sections.
            `LocalFileShareEnv` must also have at least some of the following
            parameters: {setup, upload, run, download, teardown,
            dump_params_file, read_results_file}
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).
            The asserts below require it to support both local execution and
            file share operations, so in practice it must not be None.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        # The isinstance checks also narrow the service type for the type checker.
        assert self._service is not None and isinstance(
            self._service, SupportsLocalExec
        ), "LocalEnv requires a service that supports local execution"
        self._local_exec_service: SupportsLocalExec = self._service

        assert self._service is not None and isinstance(
            self._service, SupportsFileShareOps
        ), "LocalEnv requires a service that supports file upload/download operations"
        self._file_share_service: SupportsFileShareOps = self._service

        # Pre-compile the path templates once; the values are plugged in later.
        self._upload = self._template_from_to("upload")
        self._download = self._template_from_to("download")

    def _template_from_to(self, config_key: str) -> list[tuple[Template, Template]]:
        """
        Convert a list of {"from": "...", "to": "..."} to a list of pairs of
        string.Template objects so that we can plug in self._params into it later.

        Parameters
        ----------
        config_key : str
            Name of the config section to read (e.g., "upload" or "download").
            A missing section is treated as an empty list.

        Returns
        -------
        templates : list[tuple[Template, Template]]
            (from, to) pairs of path templates, in the order of the config entries.
        """
        return [(Template(d["from"]), Template(d["to"])) for d in self.config.get(config_key, [])]

    @staticmethod
    def _expand(
        from_to: Iterable[tuple[Template, Template]],
        params: Mapping[str, TunableValue],
    ) -> Generator[tuple[str, str]]:
        """
        Substitute $var parameters in from/to path templates.

        `Template.safe_substitute` is used, so placeholders that have no
        matching key in `params` are left in the path as-is instead of
        raising a KeyError.

        Parameters
        ----------
        from_to : Iterable[tuple[Template, Template]]
            Pairs of (from, to) path templates.
        params : Mapping[str, TunableValue]
            Values to substitute for the $var placeholders.

        Returns
        -------
        paths : Generator[tuple[str, str]]
            A lazily evaluated generator of (from, to) pairs of expanded paths.
        """
        return (
            (path_from.safe_substitute(params), path_to.safe_substitute(params))
            for (path_from, path_to) in from_to
        )

    def setup(self, tunables: TunableGroups, global_config: dict | None = None) -> bool:
        """
        Run setup scripts locally and upload the scripts and data to the shared storage.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable OS and application parameters along with their
            values. In a local environment these could be used to prepare a config
            file on the scheduler prior to transferring it to the remote environment,
            for instance.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if operation is successful, false otherwise.
        """
        self._is_ready = super().setup(tunables, global_config)
        if self._is_ready:
            assert self._temp_dir is not None
            params = self._get_env_params(restrict=False)
            # Let the path templates refer to the temp directory as $PWD.
            params["PWD"] = self._temp_dir
            for path_from, path_to in self._expand(self._upload, params):
                self._file_share_service.upload(
                    self._params,
                    # The local source path is resolved with the temp directory
                    # passed as an extra search path.
                    self._config_loader_service.resolve_path(
                        path_from,
                        extra_paths=[self._temp_dir],
                    ),
                    path_to,
                )
        return self._is_ready

    def _download_files(self, ignore_missing: bool = False) -> None:
        """
        Download files from the shared storage.

        Parameters
        ----------
        ignore_missing : bool
            If True, log a warning and proceed with downloading other files
            when some file cannot be found.
            If False (default), log a warning and re-raise the exception.

        Raises
        ------
        FileNotFoundError
            If a file is missing and `ignore_missing` is False.
        Exception
            Any other error raised by the download is logged and re-raised,
            regardless of `ignore_missing`.
        """
        assert self._temp_dir is not None
        params = self._get_env_params(restrict=False)
        # Let the path templates refer to the temp directory as $PWD.
        params["PWD"] = self._temp_dir
        for path_from, path_to in self._expand(self._download, params):
            try:
                self._file_share_service.download(
                    self._params,
                    path_from,
                    # The local destination path is resolved with the temp
                    # directory passed as an extra search path.
                    self._config_loader_service.resolve_path(
                        path_to,
                        extra_paths=[self._temp_dir],
                    ),
                )
            except FileNotFoundError:
                _LOG.warning("Cannot download: %s", path_from)
                if not ignore_missing:
                    # Bare `raise` re-raises the active exception unchanged.
                    raise
            except Exception:
                _LOG.exception("Cannot download %s to %s", path_from, path_to)
                raise

    def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
        """
        Download benchmark results from the shared storage and run post-processing
        scripts locally.

        A missing file is an error here: `_download_files()` is called with
        its default `ignore_missing=False`.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        self._download_files()
        return super().run()

    def status(self) -> tuple[Status, datetime, list[tuple[datetime, str, Any]]]:
        """
        Download the files from the shared storage and check the environment status.

        Unlike `run()`, files that are missing on the shared storage are
        tolerated (a warning is logged for each); other download errors still
        propagate.

        Returns
        -------
        (status, timestamp, telemetry) : (Status, datetime.datetime, list)
            Whatever the parent class `status()` returns after the download.
        """
        self._download_files(ignore_missing=True)
        return super().status()