Coverage for mlos_bench/mlos_bench/environments/base_environment.py: 93%
148 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""A hierarchy of benchmark environments."""
7import abc
8import json
9import logging
10from collections.abc import Iterable, Sequence
11from contextlib import AbstractContextManager as ContextManager
12from datetime import datetime
13from types import TracebackType
14from typing import TYPE_CHECKING, Any, Literal
16from pytz import UTC
18from mlos_bench.config.schemas import ConfigSchema
19from mlos_bench.dict_templater import DictTemplater
20from mlos_bench.environments.status import Status
21from mlos_bench.services.base_service import Service
22from mlos_bench.tunables.tunable import TunableValue
23from mlos_bench.tunables.tunable_groups import TunableGroups
24from mlos_bench.util import instantiate_from_config, merge_parameters
26if TYPE_CHECKING:
27 from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
# Per-module logger; it is named after this module's import path so the
# logging configuration can target it specifically.
_LOG: logging.Logger = logging.getLogger(name=__name__)
class Environment(ContextManager, metaclass=abc.ABCMeta):
    # pylint: disable=too-many-instance-attributes
    """An abstract base of all benchmark environments."""

    # Should be provided by the runtime.
    _COMMON_CONST_ARGS = {
        "trial_runner_id",
    }
    # Always required: they are added to the required keys when merging
    # global_config into const_args in `__init__`.
    _COMMON_REQ_ARGS = {
        "experiment_id",
        "trial_id",
    }

    @classmethod
    def new(  # pylint: disable=too-many-arguments
        cls,
        *,
        env_name: str,
        class_name: str,
        config: dict,
        global_config: dict | None = None,
        tunables: TunableGroups | None = None,
        service: Service | None = None,
    ) -> "Environment":
        """
        Factory method for a new environment with a given config.

        Parameters
        ----------
        env_name: str
            Human-readable name of the environment.
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.environments.remote.HostEnv".
            Must be derived from the `Environment` class.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. It will be passed as a constructor parameter of
            the class specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for all environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM/Host, etc.).

        Returns
        -------
        env : Environment
            An instance of the `Environment` class initialized with `config`.
        """
        assert issubclass(cls, Environment)
        # `cls` is passed along with the class name; presumably
        # instantiate_from_config uses it as the expected base type - confirm.
        return instantiate_from_config(
            cls,
            class_name,
            name=env_name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: dict | None = None,
        tunables: TunableGroups | None = None,
        service: Service | None = None,
    ):
        """
        Create a new environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the "tunable_params"
            and the "const_args" sections.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for all environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM/Host, etc.).
        """
        # NOTE(review): a non-empty caller-supplied global_config is the same
        # object used below (and gets keys added via setdefault), while an
        # empty or missing one is replaced by a fresh dict - confirm intended.
        global_config = global_config or {}
        self._validate_json_config(config, name)
        self.name = name
        self.config = config
        self._service = service
        self._service_context: Service | None = None
        self._is_ready = False
        self._in_context = False
        # NOTE(review): when present, this is the same dict object as
        # config["const_args"], so the merge_parameters() call below
        # (presumably in-place on `dest`) also updates self.config - confirm.
        self._const_args: dict[str, TunableValue] = config.get("const_args", {})

        # Make some usual runtime arguments available for tests.
        for arg in self._COMMON_CONST_ARGS | self._COMMON_REQ_ARGS:
            global_config.setdefault(arg, self._const_args.get(arg, None))

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Environment: '%s' Service: %s",
                name,
                self._service.pprint() if self._service else None,
            )

        if tunables is None:
            _LOG.warning(
                (
                    "No tunables provided for %s. "
                    "Tunable inheritance across composite environments may be broken."
                ),
                name,
            )
            tunables = TunableGroups()

        # TODO: add user docstrings for these in the module
        # `global_config` is already a dict here (normalized above), so no
        # further `or {}` fallback is needed.
        groups = self._expand_groups(
            config.get("tunable_params", []),
            global_config.get("tunable_params_map", {}),
        )
        _LOG.debug("Tunable groups for: '%s' :: %s", name, groups)

        self._tunable_params = tunables.subgroup(groups)

        # If a parameter comes from the tunables, do not require it in the const_args or globals
        req_args = (
            set(config.get("required_args", [])) - self._tunable_params.get_param_values().keys()
        )
        req_args.update(self._COMMON_REQ_ARGS | self._COMMON_CONST_ARGS)
        merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args)
        self._const_args = self._expand_vars(self._const_args, global_config)

        self._params = self._combine_tunables(self._tunable_params)
        _LOG.debug("Parameters for '%s' :: %s", name, self._params)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2))

    def _validate_json_config(self, config: dict, name: str) -> None:
        """Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading
        mechanism.

        The reconstructed config uses this object's fully qualified class name
        as "class"; "name" and "config" are included only when non-empty.
        Validation is delegated to `ConfigSchema.ENVIRONMENT.validate`, which
        presumably raises on an invalid config (exception type not visible here).
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if name:
            json_config["name"] = name
        if config:
            json_config["config"] = config
        ConfigSchema.ENVIRONMENT.validate(json_config)

    @staticmethod
    def _expand_groups(
        groups: Iterable[str],
        groups_exp: dict[str, str | Sequence[str]],
    ) -> list[str]:
        """
        Expand `$tunable_group` into actual names of the tunable groups.

        Parameters
        ----------
        groups : Iterable[str]
            Names of the groups of tunables, maybe with `$` prefix (subject to expansion).
        groups_exp : dict
            A dictionary that maps dollar variables for tunable groups to the lists
            of actual tunable groups IDs.

        Returns
        -------
        groups : list[str]
            A flat list of tunable groups IDs for the environment.

        Raises
        ------
        KeyError
            If a `$`-prefixed name has no entry in `groups_exp`.
        """
        res: list[str] = []
        for grp in groups:
            if grp[:1] == "$":
                tunable_group_name = grp[1:]
                if tunable_group_name not in groups_exp:
                    # Both halves must be f-strings, or `{groups_exp}` is
                    # emitted literally instead of the mapping's contents.
                    raise KeyError(
                        f"Expected tunable group name ${tunable_group_name} "
                        f"undefined in {groups_exp}"
                    )
                add_groups = groups_exp[tunable_group_name]
                # A bare string is a single group ID; otherwise splice in the sequence.
                res += [add_groups] if isinstance(add_groups, str) else add_groups
            else:
                res.append(grp)
        return res

    @staticmethod
    def _expand_vars(
        params: dict[str, TunableValue],
        global_config: dict[str, TunableValue],
    ) -> dict:
        """Expand `$var` into actual values of the variables.

        Delegates to `DictTemplater`, with `global_config` supplied as an extra
        source of variable values.
        """
        return DictTemplater(params).expand_vars(extra_source_dict=global_config)

    @property
    def _config_loader_service(self) -> "SupportsConfigLoading":
        """The config loader service exposed by the attached service.

        Asserts that a service was provided at construction time.
        """
        assert self._service is not None
        return self._service.config_loader_service

    def __enter__(self) -> "Environment":
        """Enter the environment's benchmarking context.

        Asserts the environment is not already in a context. If a service is
        attached, its context is entered too and the result is kept so it can
        be exited in `__exit__`.
        """
        _LOG.debug("Environment START :: %s", self)
        assert not self._in_context
        if self._service:
            self._service_context = self._service.__enter__()
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """Exit the context of the benchmarking environment.

        Exits the service context (if one was entered), always clears the
        context state, and re-raises any exception raised by the service's own
        `__exit__`. Never suppresses the exception that ended the block.
        """
        ex_throw = None
        if ex_val is None:
            _LOG.debug("Environment END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self._in_context
        if self._service_context:
            try:
                self._service_context.__exit__(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting Service context '%s': %s", self._service, ex)
                # Defer re-raising until the environment's own state is reset.
                ex_throw = ex
            finally:
                self._service_context = None
        self._in_context = False
        if ex_throw:
            raise ex_throw
        return False  # Do not suppress exceptions

    def __str__(self) -> str:
        """Return the human-readable environment name."""
        return self.name

    def __repr__(self) -> str:
        """Return `<ClassName> :: '<name>'`."""
        return f"{self.__class__.__name__} :: '{self.name}'"

    def pprint(self, indent: int = 4, level: int = 0) -> str:
        """
        Pretty-print the environment configuration. For composite environments, print
        all children environments as well.

        Parameters
        ----------
        indent : int
            Number of spaces to indent the output. Default is 4.
        level : int
            Current level of indentation. Default is 0.

        Returns
        -------
        pretty : str
            Pretty-printed environment configuration.
            Default output is the same as `__repr__`.
        """
        return f'{" " * indent * level}{repr(self)}'

    def _combine_tunables(self, tunables: TunableGroups) -> dict[str, TunableValue]:
        """
        Plug tunable values into the base config. If the tunable group is unknown,
        ignore it (it might belong to another environment). This method should never
        mutate the original config or the tunables.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of groups of tunable parameters
            along with the parameters' values.

        Returns
        -------
        params : dict[str, Union[int, float, str]]
            Free-format dictionary that contains the new environment configuration.
        """
        # Only this environment's covariant groups are read; values are layered
        # on a copy of const_args so `self._const_args` itself is left untouched.
        return tunables.get_param_values(
            group_names=list(self._tunable_params.get_covariant_group_names()),
            into_params=self._const_args.copy(),
        )

    @property
    def tunable_params(self) -> TunableGroups:
        """
        Get the configuration space of the given environment.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant groups of tunable parameters.
        """
        return self._tunable_params

    @property
    def const_args(self) -> dict[str, TunableValue]:
        """
        Get the constant arguments for this Environment.

        Returns
        -------
        parameters : Dict[str, TunableValue]
            Key/value pairs of all environment const_args parameters
            (a shallow copy).
        """
        return self._const_args.copy()

    @property
    def parameters(self) -> dict[str, TunableValue]:
        """
        Key/value pairs of all environment parameters (i.e., `const_args` and
        `tunable_params`). Note that before `.setup()` is called, all tunables will be
        set to None.

        Returns
        -------
        parameters : dict[str, TunableValue]
            Key/value pairs of all environment parameters
            (i.e., `const_args` and `tunable_params`); a shallow copy.
        """
        return self._params.copy()

    def setup(self, tunables: TunableGroups, global_config: dict | None = None) -> bool:
        """
        Set up a new benchmark environment, if necessary. This method must be
        idempotent, i.e., calling it several times in a row should be equivalent to a
        single call.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable parameters along with their values.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if operation is successful, false otherwise.
        """
        _LOG.info("Setup %s :: %s", self, tunables)
        assert isinstance(tunables, TunableGroups)

        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context

        # Assign new values to the environment's tunable parameters:
        groups = list(self._tunable_params.get_covariant_group_names())
        self._tunable_params.assign(tunables.get_param_values(groups))

        # Write to the log whether the environment needs to be reset.
        # (Derived classes still have to check `self._tunable_params.is_updated()`).
        is_updated = self._tunable_params.is_updated()
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Env '%s': Tunable groups reset = %s :: %s",
                self,
                is_updated,
                {
                    name: self._tunable_params.is_updated([name])
                    for name in self._tunable_params.get_covariant_group_names()
                },
            )
        else:
            _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated)

        # Combine tunables, const_args, and global config into `self._params`:
        self._params = self._combine_tunables(tunables)
        # NOTE(review): global_config may be None here; presumably
        # merge_parameters tolerates a None source - confirm.
        merge_parameters(dest=self._params, source=global_config)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Combined parameters:\n%s", json.dumps(self._params, indent=2))

        return True

    def teardown(self) -> None:
        """
        Tear down the benchmark environment.

        This method must be idempotent, i.e., calling it several times in a row should
        be equivalent to a single call. The base implementation only clears the
        ready flag.
        """
        _LOG.info("Teardown %s", self)
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        self._is_ready = False

    def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
        """
        Execute the run script for this environment.

        For instance, this may start a new experiment, download results, reconfigure
        the environment, etc. Details are configurable via the environment config.
        The base implementation reports the current `status()` with no output.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        (status, timestamp, _) = self.status()
        return (status, timestamp, None)

    def status(self) -> tuple[Status, datetime, list[tuple[datetime, str, Any]]]:
        """
        Check the status of the benchmark environment.

        The base implementation returns READY if the ready flag is set and
        PENDING (with a warning) otherwise, always with empty telemetry.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is UTC time stamp of the status; it's current time by default.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
        """
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        timestamp = datetime.now(UTC)
        if self._is_ready:
            return (Status.READY, timestamp, [])
        _LOG.warning("Environment not ready: %s", self)
        return (Status.PENDING, timestamp, [])