Coverage for mlos_bench/mlos_bench/environments/local/local_env.py: 96%
126 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Scheduler-side benchmark environment to run scripts locally.
8TODO: Reference the script_env.py file for the base class.
9"""
11import json
12import logging
13import sys
14from collections.abc import Iterable, Mapping
15from contextlib import nullcontext
16from datetime import datetime
17from tempfile import TemporaryDirectory
18from types import TracebackType
19from typing import Any, Literal
21import pandas
23from mlos_bench.environments.base_environment import Environment
24from mlos_bench.environments.script_env import ScriptEnv
25from mlos_bench.environments.status import Status
26from mlos_bench.services.base_service import Service
27from mlos_bench.services.types.local_exec_type import SupportsLocalExec
28from mlos_bench.tunables.tunable import TunableValue
29from mlos_bench.tunables.tunable_groups import TunableGroups
30from mlos_bench.util import datetime_parser, path_join
# Module-level logger, named after this module (`__name__`).
_LOG = logging.getLogger(__name__)
class LocalEnv(ScriptEnv):
    # pylint: disable=too-many-instance-attributes
    """Scheduler-side Environment that runs scripts locally."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: dict | None = None,
        tunables: TunableGroups | None = None,
        service: Service | None = None,
    ):
        """
        Create a new environment for local execution.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the "tunable_params"
            and the "const_args" sections.
            `LocalEnv` must also have at least some of the following parameters:
            {setup, run, teardown, dump_params_file, read_results_file}
            It may additionally specify "dump_meta_file", "read_telemetry_file",
            and "temp_dir" (the latter is read when entering the context).
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of tunable parameters for *all* environments.
        service: Service
            A service object (e.g., providing methods to deploy or reboot a VM,
            etc.). Despite the `None` default, it is required here and must
            implement `SupportsLocalExec`.

        Raises
        ------
        AssertionError
            If `service` is missing or does not support local execution.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        assert self._service is not None and isinstance(
            self._service, SupportsLocalExec
        ), "LocalEnv requires a service that supports local execution"
        self._local_exec_service: SupportsLocalExec = self._service

        # Working directory for the scripts and data files. It is only set
        # while the environment is entered as a context (see __enter__/__exit__).
        self._temp_dir: str | None = None
        self._temp_dir_context: TemporaryDirectory | nullcontext | None = None

        # Output files (joined with the temp dir) written during setup().
        self._dump_params_file: str | None = self.config.get("dump_params_file")
        self._dump_meta_file: str | None = self.config.get("dump_meta_file")

        # Input files (resolved with the temp dir as an extra search path)
        # read by run() and status(), respectively.
        self._read_results_file: str | None = self.config.get("read_results_file")
        self._read_telemetry_file: str | None = self.config.get("read_telemetry_file")

    def __enter__(self) -> Environment:
        """
        Enter the context of the benchmarking environment.

        Acquires the working directory (from the local exec service, honoring the
        optional "temp_dir" config value) before entering the parent context.
        If entering the parent context fails, the working directory is released
        again and the exception is re-raised.
        """
        assert self._temp_dir is None and self._temp_dir_context is None
        temp_dir_context = self._local_exec_service.temp_dir_context(
            self.config.get("temp_dir"),
        )
        self._temp_dir = temp_dir_context.__enter__()
        self._temp_dir_context = temp_dir_context
        try:
            return super().__enter__()
        except BaseException:
            # Do not leak the temp dir, nor leave stale state behind that would
            # trip the assertion above on the next attempt to enter.
            self._temp_dir = None
            self._temp_dir_context = None
            temp_dir_context.__exit__(*sys.exc_info())
            raise

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """
        Exit the context of the benchmarking environment.

        Releases the working directory first and then exits the parent context.
        The parent context is exited and the local state is reset even if
        releasing the working directory raises.
        """
        assert not (self._temp_dir is None or self._temp_dir_context is None)
        temp_dir_context = self._temp_dir_context
        self._temp_dir = None
        self._temp_dir_context = None
        try:
            temp_dir_context.__exit__(ex_type, ex_val, ex_tb)
        finally:
            # Always leave the parent context, even if the temp dir cleanup failed.
            exit_result = super().__exit__(ex_type, ex_val, ex_tb)
        return exit_result

    def setup(self, tunables: TunableGroups, global_config: dict | None = None) -> bool:
        """
        Check if the environment is ready and set up the application and benchmarks, if
        necessary.

        After the parent setup succeeds, optionally dumps the tunable values to
        `dump_params_file` and the (non-empty) tunable metadata to `dump_meta_file`
        as JSON (both paths joined with the temp dir). Then runs the setup
        script, if any, in the temp dir.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable OS and application parameters along with their
            values. In a local environment these could be used to prepare a config
            file on the scheduler prior to transferring it to the remote environment,
            for instance.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if operation is successful, false otherwise.
            With a setup script, success means the script returned 0.
        """
        if not super().setup(tunables, global_config):
            return False

        _LOG.info("Set up the environment locally: '%s' at %s", self, self._temp_dir)
        # setup() must be called within the environment's context.
        assert self._temp_dir is not None

        if self._dump_params_file:
            fname = path_join(self._temp_dir, self._dump_params_file)
            _LOG.debug("Dump tunables to file: %s", fname)
            with open(fname, "w", encoding="utf-8") as fh_tunables:
                # Only the tunable values are dumped, not the const_args.
                json.dump(self._tunable_params.get_param_values(), fh_tunables)

        if self._dump_meta_file:
            fname = path_join(self._temp_dir, self._dump_meta_file)
            _LOG.debug("Dump tunables metadata to file: %s", fname)
            with open(fname, "w", encoding="utf-8") as fh_meta:
                # Map tunable name -> metadata, skipping tunables without metadata.
                json.dump(
                    {
                        tunable.name: tunable.meta
                        for (tunable, _group) in self._tunable_params
                        if tunable.meta
                    },
                    fh_meta,
                )

        if self._script_setup:
            (return_code, _output) = self._local_exec(self._script_setup, self._temp_dir)
            self._is_ready = bool(return_code == 0)
        else:
            self._is_ready = True

        return self._is_ready

    def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
        """
        Run a script in the local scheduler environment.

        Runs the run script (if any) in the temp dir and parses its stdout via
        `_extract_stdout_results`. If `read_results_file` is configured, the
        results CSV is then read and merged on top of the stdout results. The CSV
        must either have exactly the columns `metric,value` ("long" format,
        one row per metric) or contain exactly one row ("wide" format, one
        column per metric).

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results, or None if the run script failed. If the parent's
            readiness check fails, the parent's result is returned unchanged.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.

        Raises
        ------
        ValueError
            If the results CSV is in neither of the supported formats.
        """
        (status, timestamp, _) = result = super().run()
        if not status.is_ready():
            return result

        assert self._temp_dir is not None

        stdout_data: dict[str, TunableValue] = {}
        if self._script_run:
            (return_code, output) = self._local_exec(self._script_run, self._temp_dir)
            if return_code != 0:
                return (Status.FAILED, timestamp, None)
            stdout_data = self._extract_stdout_results(output.get("stdout", ""))

        # FIXME: We should not be assuming that the only output file type is a CSV.
        if not self._read_results_file:
            _LOG.debug("Not reading the data at: %s", self)
            return (Status.SUCCEEDED, timestamp, stdout_data)

        data = self._normalize_columns(
            pandas.read_csv(
                self._config_loader_service.resolve_path(
                    self._read_results_file,
                    extra_paths=[self._temp_dir],
                ),
                index_col=False,
            )
        )

        _LOG.debug("Read data:\n%s", data)
        if list(data.columns) == ["metric", "value"]:
            _LOG.info(
                "Local results have (metric,value) header and %d rows: assume long format",
                len(data),
            )
            # Pivot the long format into a single wide row: one column per metric.
            data = pandas.DataFrame([data["value"].to_list()], columns=data["metric"].to_list())
            # Try to convert string metrics to numbers; keep the original
            # (string) value wherever the numeric conversion fails.
            data = data.apply(
                pandas.to_numeric,
                errors="coerce",
            ).fillna(data)
        elif len(data) == 1:
            _LOG.info("Local results have 1 row: assume wide format")
        else:
            raise ValueError(f"Invalid data format: {data}")

        # Values from the results file override same-named stdout results.
        stdout_data.update(data.iloc[-1].to_dict())
        _LOG.info("Local run complete: %s ::\n%s", self, stdout_data)
        return (Status.SUCCEEDED, timestamp, stdout_data)

    @staticmethod
    def _normalize_columns(data: pandas.DataFrame) -> pandas.DataFrame:
        """
        Strip trailing whitespace from column names (Windows only).

        Modifies `data` in place and returns it for convenience.
        """
        # Windows cmd interpretation of > redirect symbols can leave trailing spaces in
        # the final column, which leads to misnamed columns.
        # For now, we simply strip trailing spaces from column names to account for that.
        if sys.platform == "win32":
            data.rename(str.rstrip, axis="columns", inplace=True)
        return data

    def status(self) -> tuple[Status, datetime, list[tuple[datetime, str, Any]]]:
        """
        Check the status of the environment and read its telemetry, if any.

        The telemetry CSV (`read_telemetry_file`) must have exactly three
        columns: timestamp, metric, value. The header line is optional.
        Timestamps are parsed with `datetime_parser(..., origin="local")`.

        Returns
        -------
        (status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple where `telemetry` is a list of (timestamp, metric, value)
            tuples. It is empty if the environment is not ready, no telemetry
            file is configured, or the telemetry file is not found.

        Raises
        ------
        ValueError
            If the telemetry data does not have exactly three columns.
        """
        (status, timestamp, _) = super().status()
        if not (self._is_ready and self._read_telemetry_file):
            return (status, timestamp, [])

        assert self._temp_dir is not None
        try:
            fname = self._config_loader_service.resolve_path(
                self._read_telemetry_file,
                extra_paths=[self._temp_dir],
            )

            # TODO: Use the timestamp of the CSV file as our status timestamp?

            # FIXME: We should not be assuming that the only output file type is a CSV.

            data = self._normalize_columns(pandas.read_csv(fname, index_col=False))

            expected_col_names = ["timestamp", "metric", "value"]
            if len(data.columns) != len(expected_col_names):
                raise ValueError(f"Telemetry data must have columns {expected_col_names}")

            if list(data.columns) != expected_col_names:
                # Assume no header - this is ok for telemetry data.
                # Re-read so that the first line is treated as data, not as column names.
                data = pandas.read_csv(fname, index_col=False, names=expected_col_names)

            # Parse the timestamps once, after the final frame has been determined.
            data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local")

        except FileNotFoundError as ex:
            _LOG.warning("Telemetry CSV file not found: %s :: %s", self._read_telemetry_file, ex)
            return (status, timestamp, [])

        _LOG.debug("Read telemetry data:\n%s", data)
        col_dtypes: Mapping[int, type] = {0: datetime}
        return (
            status,
            timestamp,
            [
                (pandas.Timestamp(ts).to_pydatetime(), metric, value)
                for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes)
            ],
        )

    def teardown(self) -> None:
        """
        Clean up the local environment.

        Runs the teardown script (if any) and then the parent's teardown.
        A non-zero return code from the script is only logged, not raised.
        """
        if self._script_teardown:
            _LOG.info("Local teardown: %s", self)
            # NOTE(review): unlike setup()/run(), no cwd is passed here, so
            # _local_exec is called with cwd=None rather than self._temp_dir.
            # Confirm this is intended.
            (return_code, _output) = self._local_exec(self._script_teardown)
            _LOG.info("Local teardown complete: %s :: %s", self, return_code)
        super().teardown()

    def _local_exec(self, script: Iterable[str], cwd: str | None = None) -> tuple[int, dict]:
        """
        Execute a script locally in the scheduler environment.

        Parameters
        ----------
        script : Iterable[str]
            Lines of the script to run locally.
            Treat every line as a separate command to run.
        cwd : str | None
            Work directory to run the script at.

        Returns
        -------
        (return_code, output) : (int, dict)
            Return code of the script and a dict with stdout/stderr. Return code = 0 if successful.
            A non-zero return code is logged as a warning along with stderr.
        """
        env_params = self._get_env_params()
        # NOTE(review): env params are logged at INFO level. If they can carry
        # secrets (e.g., credentials from global_config), consider DEBUG or redaction.
        _LOG.info("Run script locally on: %s at %s with env %s", self, cwd, env_params)
        (return_code, stdout, stderr) = self._local_exec_service.local_exec(
            script,
            env=env_params,
            cwd=cwd,
        )
        if return_code != 0:
            _LOG.warning("ERROR: Local script returns code %d stderr:\n%s", return_code, stderr)
        return (return_code, {"stdout": stdout, "stderr": stderr})