Coverage for mlos_bench/mlos_bench/environments/local/local_env.py: 88%
126 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Scheduler-side benchmark environment to run scripts locally."""
7import json
8import logging
9import sys
10from contextlib import nullcontext
11from datetime import datetime
12from tempfile import TemporaryDirectory
13from types import TracebackType
14from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union
16import pandas
17from typing_extensions import Literal
19from mlos_bench.environments.base_environment import Environment
20from mlos_bench.environments.script_env import ScriptEnv
21from mlos_bench.environments.status import Status
22from mlos_bench.services.base_service import Service
23from mlos_bench.services.types.local_exec_type import SupportsLocalExec
24from mlos_bench.tunables.tunable import TunableValue
25from mlos_bench.tunables.tunable_groups import TunableGroups
26from mlos_bench.util import datetime_parser, path_join
28_LOG = logging.getLogger(__name__)
31class LocalEnv(ScriptEnv):
32 # pylint: disable=too-many-instance-attributes
33 """Scheduler-side Environment that runs scripts locally."""
def __init__(  # pylint: disable=too-many-arguments
    self,
    *,
    name: str,
    config: dict,
    global_config: Optional[dict] = None,
    tunables: Optional[TunableGroups] = None,
    service: Optional[Service] = None,
):
    """
    Initialize an environment whose scripts are executed locally.

    Parameters
    ----------
    name : str
        Human-readable name of the environment.
    config : dict
        Free-format dictionary with the benchmark environment configuration.
        Each config must have at least the "tunable_params" and the
        "const_args" sections. `LocalEnv` must also have at least some of
        the following parameters:
        {setup, run, teardown, dump_params_file, read_results_file}
    global_config : dict
        Free-format dictionary of global parameters (e.g., security
        credentials) to be mixed in into the "const_args" section of the
        local config.
    tunables : TunableGroups
        A collection of tunable parameters for *all* environments.
    service : Service
        An optional service object (e.g., providing methods to deploy or
        reboot a VM, etc.). It must support local execution.
    """
    super().__init__(
        name=name,
        config=config,
        global_config=global_config,
        tunables=tunables,
        service=service,
    )

    # Narrow the generic service to one that can run scripts locally.
    local_service = self._service
    assert local_service is not None and isinstance(
        local_service, SupportsLocalExec
    ), "LocalEnv requires a service that supports local execution"
    self._local_exec_service: SupportsLocalExec = local_service

    # Both are populated in `__enter__` and cleared again in `__exit__`.
    self._temp_dir_context: Union[TemporaryDirectory, nullcontext, None] = None
    self._temp_dir: Optional[str] = None

    # Optional file names; a missing config key leaves the attribute as None.
    env_config = self.config
    self._dump_params_file: Optional[str] = env_config.get("dump_params_file")
    self._dump_meta_file: Optional[str] = env_config.get("dump_meta_file")
    self._read_results_file: Optional[str] = env_config.get("read_results_file")
    self._read_telemetry_file: Optional[str] = env_config.get("read_telemetry_file")
def __enter__(self) -> Environment:
    """
    Enter the context of the benchmarking environment.

    Asks the local execution service for a temporary-directory context
    (passing the optional "temp_dir" config value), enters it, and stores
    the value it yields before delegating to the parent class.

    Returns
    -------
    env : Environment
        The value returned by the parent class `__enter__`.
    """
    # The environment must not already be inside a context.
    assert self._temp_dir is None
    assert self._temp_dir_context is None

    temp_dir_setting = self.config.get("temp_dir")
    self._temp_dir_context = self._local_exec_service.temp_dir_context(temp_dir_setting)
    self._temp_dir = self._temp_dir_context.__enter__()
    return super().__enter__()
def __exit__(
    self,
    ex_type: Optional[Type[BaseException]],
    ex_val: Optional[BaseException],
    ex_tb: Optional[TracebackType],
) -> Literal[False]:
    """
    Exit the context of the benchmarking environment.

    Leaves the temporary-directory context, clears the stored temp dir
    state, and then exits the parent context. The state reset and the
    parent `__exit__` run even if leaving the temporary-directory context
    raises, so the environment is never left half-exited.

    Parameters
    ----------
    ex_type : Optional[Type[BaseException]]
        Type of the exception raised inside the context, if any.
    ex_val : Optional[BaseException]
        The exception raised inside the context, if any.
    ex_tb : Optional[TracebackType]
        Traceback of the exception raised inside the context, if any.

    Returns
    -------
    result : Literal[False]
        The value returned by the parent class `__exit__`.
    """
    assert not (self._temp_dir is None or self._temp_dir_context is None)
    try:
        self._temp_dir_context.__exit__(ex_type, ex_val, ex_tb)
    finally:
        # Always reset our state and unwind the parent context; an error
        # from the temp dir cleanup (if any) still propagates afterwards.
        self._temp_dir = None
        self._temp_dir_context = None
        parent_result = super().__exit__(ex_type, ex_val, ex_tb)
    return parent_result
def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
    """
    Prepare the local environment: dump the parameter files and run the setup script.

    Parameters
    ----------
    tunables : TunableGroups
        A collection of tunable OS and application parameters along with their
        values. In a local environment these could be used to prepare a config
        file on the scheduler prior to transferring it to the remote environment,
        for instance.
    global_config : dict
        Free-format dictionary of global parameters of the environment
        that are not used in the optimization process.

    Returns
    -------
    is_success : bool
        True if operation is successful, false otherwise.
    """
    if not super().setup(tunables, global_config):
        return False

    _LOG.info("Set up the environment locally: '%s' at %s", self, self._temp_dir)
    work_dir = self._temp_dir
    assert work_dir is not None

    if self._dump_params_file:
        params_path = path_join(work_dir, self._dump_params_file)
        _LOG.debug("Dump tunables to file: %s", params_path)
        with open(params_path, "wt", encoding="utf-8") as params_fh:
            # Only the tunable values are written here, not the const_args.
            param_values = self._tunable_params.get_param_values()
            json.dump(param_values, params_fh)

    if self._dump_meta_file:
        meta_path = path_join(work_dir, self._dump_meta_file)
        _LOG.debug("Dump tunables metadata to file: %s", meta_path)
        with open(meta_path, "wt", encoding="utf-8") as meta_fh:
            # Tunables with empty metadata are left out of the file.
            tunables_meta = {
                tunable.name: tunable.meta
                for (tunable, _group) in self._tunable_params
                if tunable.meta
            }
            json.dump(tunables_meta, meta_fh)

    if not self._script_setup:
        # Nothing to execute: the environment is ready as is.
        self._is_ready = True
    else:
        (exit_code, _) = self._local_exec(self._script_setup, work_dir)
        self._is_ready = exit_code == 0

    return self._is_ready
def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
    """
    Execute the run script locally and collect its results.

    Results parsed from the script's stdout are merged with the values read
    from the results CSV file, if one is configured. The CSV may be either in
    the long format (a `metric,value` header, one row per metric) or in the
    wide format (exactly one row, one column per metric).

    Returns
    -------
    (status, timestamp, output) : (Status, datetime, dict)
        3-tuple of (Status, timestamp, output) values, where `output` is a dict
        with the results, or None if the run script returned a non-zero code.
        If the parent class reports a status that is not ready, its result is
        returned unchanged.
        If run script is a benchmark, then the score is usually expected to
        be in the `score` field.

    Raises
    ------
    ValueError
        If the results file is neither in the long nor in the wide format.
    """
    parent_result = super().run()
    (status, timestamp, _) = parent_result
    if not status.is_ready():
        return parent_result

    assert self._temp_dir is not None

    results: Dict[str, TunableValue] = {}
    if self._script_run:
        (exit_code, output) = self._local_exec(self._script_run, self._temp_dir)
        if exit_code != 0:
            return (Status.FAILED, timestamp, None)
        results = self._extract_stdout_results(output.get("stdout", ""))

    # FIXME: We should not be assuming that the only output file type is a CSV.
    if not self._read_results_file:
        _LOG.debug("Not reading the data at: %s", self)
        return (Status.SUCCEEDED, timestamp, results)

    results_path = self._config_loader_service.resolve_path(
        self._read_results_file,
        extra_paths=[self._temp_dir],
    )
    data = self._normalize_columns(pandas.read_csv(results_path, index_col=False))

    _LOG.debug("Read data:\n%s", data)
    is_long_format = list(data.columns) == ["metric", "value"]
    if not is_long_format and len(data) != 1:
        raise ValueError(f"Invalid data format: {data}")

    if is_long_format:
        _LOG.info(
            "Local results have (metric,value) header and %d rows: assume long format",
            len(data),
        )
        # Pivot to a single row with one column per metric.
        pivoted = pandas.DataFrame([data.value.to_list()], columns=data.metric.to_list())
        # Try to convert string metrics to numbers; keep the original value
        # wherever the conversion fails.
        data = pivoted.apply(  # type: ignore[assignment] # (false positive)
            pandas.to_numeric,
            errors="coerce",
        ).fillna(pivoted)
    else:
        _LOG.info("Local results have 1 row: assume wide format")

    results.update(data.iloc[-1].to_dict())
    _LOG.info("Local run complete: %s ::\n%s", self, results)
    return (Status.SUCCEEDED, timestamp, results)
@staticmethod
def _normalize_columns(data: pandas.DataFrame) -> pandas.DataFrame:
    """
    Strip trailing spaces from column names (Windows only).

    On Windows the frame is modified in place; on any other platform it is
    left untouched. The same `data` object is returned in both cases.
    """
    if sys.platform != "win32":
        return data
    # Windows cmd interpretation of > redirect symbols can leave trailing spaces
    # in the final column, which leads to misnamed columns. For now, we simply
    # strip trailing spaces from column names to account for that.
    data.rename(columns=str.rstrip, inplace=True)
    return data
def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
    """
    Check the status of the environment and read its telemetry, if available.

    Telemetry is read from the CSV file named by the "read_telemetry_file"
    config entry. The file must have exactly three columns
    (timestamp, metric, value); the header row is optional.

    Returns
    -------
    (status, timestamp, telemetry) : (Status, datetime, list)
        3-tuple of the status and timestamp reported by the parent class, and
        a list of (timestamp, metric, value) telemetry records. The list is
        empty if the environment is not ready, no telemetry file is
        configured, or the file cannot be found.

    Raises
    ------
    ValueError
        If the telemetry file does not have exactly three columns.
    """
    (status, timestamp, _) = super().status()
    if not (self._is_ready and self._read_telemetry_file):
        return (status, timestamp, [])

    assert self._temp_dir is not None
    expected_col_names = ["timestamp", "metric", "value"]
    try:
        fname = self._config_loader_service.resolve_path(
            self._read_telemetry_file,
            extra_paths=[self._temp_dir],
        )

        # TODO: Use the timestamp of the CSV file as our status timestamp?

        # FIXME: We should not be assuming that the only output file type is a CSV.

        data = self._normalize_columns(pandas.read_csv(fname, index_col=False))

        # Validate the shape first so that malformed files are reported with
        # this message rather than with a timestamp parsing error.
        if len(data.columns) != len(expected_col_names):
            raise ValueError(f"Telemetry data must have columns {expected_col_names}")

        if list(data.columns) != expected_col_names:
            # Assume no header - this is ok for telemetry data.
            data = pandas.read_csv(fname, index_col=False, names=expected_col_names)

        # Parse the timestamps once, on whichever frame we ended up with.
        data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local")

    except FileNotFoundError as ex:
        _LOG.warning("Telemetry CSV file not found: %s :: %s", self._read_telemetry_file, ex)
        return (status, timestamp, [])

    _LOG.debug("Read telemetry data:\n%s", data)
    col_dtypes: Mapping[int, Type] = {0: datetime}
    return (
        status,
        timestamp,
        [
            (pandas.Timestamp(ts).to_pydatetime(), metric, value)
            for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes)
        ],
    )
def teardown(self) -> None:
    """
    Clean up the local environment.

    Runs the teardown script, if one is configured, and then delegates to the
    parent class. The script is executed without an explicit working
    directory, and its return code is only logged.
    """
    teardown_script = self._script_teardown
    if teardown_script:
        _LOG.info("Local teardown: %s", self)
        exit_code = self._local_exec(teardown_script)[0]
        _LOG.info("Local teardown complete: %s :: %s", self, exit_code)
    super().teardown()
def _local_exec(self, script: Iterable[str], cwd: Optional[str] = None) -> Tuple[int, dict]:
    """
    Run a script in the local scheduler environment.

    The script is passed to the local execution service together with the
    environment parameters obtained from `_get_env_params()`.

    Parameters
    ----------
    script : Iterable[str]
        Lines of the script to run locally.
        Treat every line as a separate command to run.
    cwd : Optional[str]
        Work directory to run the script at.

    Returns
    -------
    (return_code, output) : (int, dict)
        Return code of the script and a dict with the "stdout" and "stderr"
        entries. Return code = 0 if successful.
    """
    env = self._get_env_params()
    _LOG.info("Run script locally on: %s at %s with env %s", self, cwd, env)
    exec_result = self._local_exec_service.local_exec(script, env=env, cwd=cwd)
    (exit_code, stdout, stderr) = exec_result
    if exit_code != 0:
        # A failure is only logged here; callers decide how to react to the code.
        _LOG.warning("ERROR: Local script returns code %d stderr:\n%s", exit_code, stderr)
    output = {"stdout": stdout, "stderr": stderr}
    return (exit_code, output)