Coverage for mlos_bench/mlos_bench/environments/base_environment.py: 93%
137 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""A hierarchy of benchmark environments."""
7import abc
8import json
9import logging
10from datetime import datetime
11from types import TracebackType
12from typing import (
13 TYPE_CHECKING,
14 Any,
15 Dict,
16 Iterable,
17 List,
18 Literal,
19 Optional,
20 Sequence,
21 Tuple,
22 Type,
23 Union,
24)
26from pytz import UTC
28from mlos_bench.config.schemas import ConfigSchema
29from mlos_bench.dict_templater import DictTemplater
30from mlos_bench.environments.status import Status
31from mlos_bench.services.base_service import Service
32from mlos_bench.tunables.tunable import TunableValue
33from mlos_bench.tunables.tunable_groups import TunableGroups
34from mlos_bench.util import instantiate_from_config, merge_parameters
36if TYPE_CHECKING:
37 from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
# Module-level logger, named after this module so records can be filtered per module.
_LOG = logging.getLogger(name=__name__)
class Environment(metaclass=abc.ABCMeta):
    # pylint: disable=too-many-instance-attributes
    """An abstract base of all benchmark environments."""

    @classmethod
    def new(  # pylint: disable=too-many-arguments
        cls,
        *,
        env_name: str,
        class_name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ) -> "Environment":
        """
        Factory method for a new environment with a given config.

        Parameters
        ----------
        env_name: str
            Human-readable name of the environment.
        class_name: str
            FQN of a Python class to instantiate, e.g.,
            "mlos_bench.environments.remote.HostEnv".
            Must be derived from the `Environment` class.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. It will be passed as a constructor parameter of
            the class specified by `class_name`.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for all environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM/Host, etc.).

        Returns
        -------
        env : Environment
            An instance of the `Environment` class initialized with `config`.
        """
        assert issubclass(cls, Environment)
        return instantiate_from_config(
            cls,
            class_name,
            name=env_name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the benchmark environment
            configuration. Each config must have at least the "tunable_params"
            and the "const_args" sections.
            The caller's dict (including its "const_args" section) is not modified.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for all environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM/Host, etc.).
        """
        self._validate_json_config(config, name)
        self.name = name
        self.config = config
        self._service = service
        self._service_context: Optional[Service] = None
        self._is_ready = False
        self._in_context = False
        # Shallow copy: `merge_parameters()` below merges into `dest` in place,
        # and we must not leak global values back into the caller's config dict
        # (which may be reused to build other environments).
        self._const_args: Dict[str, TunableValue] = dict(config.get("const_args", {}))

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Environment: '%s' Service: %s",
                name,
                self._service.pprint() if self._service else None,
            )

        if tunables is None:
            _LOG.warning(
                (
                    "No tunables provided for %s. "
                    "Tunable inheritance across composite environments may be broken."
                ),
                name,
            )
            tunables = TunableGroups()

        # TODO: add user docstrings for these in the module
        # Resolve `$group` references in "tunable_params" via the global
        # "tunable_params_map" into concrete tunable group names.
        groups = self._expand_groups(
            config.get("tunable_params", []),
            (global_config or {}).get("tunable_params_map", {}),
        )
        _LOG.debug("Tunable groups for: '%s' :: %s", name, groups)

        self._tunable_params = tunables.subgroup(groups)

        # If a parameter comes from the tunables, do not require it in the const_args or globals
        req_args = set(config.get("required_args", [])) - set(
            self._tunable_params.get_param_values().keys()
        )
        merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args)
        self._const_args = self._expand_vars(self._const_args, global_config or {})

        self._params = self._combine_tunables(self._tunable_params)
        _LOG.debug("Parameters for '%s' :: %s", name, self._params)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2))

    def _validate_json_config(self, config: dict, name: str) -> None:
        """
        Reconstruct a basic json config that this class might have been instantiated
        from, and validate it against the environment schema.

        This lets configs that were provided outside the file loading mechanism
        (e.g., constructed in code) go through the same validation.
        Empty `name`/`config` values are omitted from the reconstructed config.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if name:
            json_config["name"] = name
        if config:
            json_config["config"] = config
        ConfigSchema.ENVIRONMENT.validate(json_config)

    @staticmethod
    def _expand_groups(
        groups: Iterable[str],
        groups_exp: Dict[str, Union[str, Sequence[str]]],
    ) -> List[str]:
        """
        Expand `$tunable_group` into actual names of the tunable groups.

        Parameters
        ----------
        groups : List[str]
            Names of the groups of tunables, maybe with `$` prefix (subject to expansion).
        groups_exp : dict
            A dictionary that maps dollar variables for tunable groups to the lists
            of actual tunable groups IDs.

        Returns
        -------
        groups : List[str]
            A flat list of tunable groups IDs for the environment.

        Raises
        ------
        KeyError
            If a `$`-prefixed group name has no entry in `groups_exp`.
        """
        res: List[str] = []
        for grp in groups:
            if grp[:1] == "$":
                tunable_group_name = grp[1:]
                if tunable_group_name not in groups_exp:
                    # Both halves must be f-strings so the mapping is actually shown.
                    raise KeyError(
                        f"Expected tunable group name ${tunable_group_name} "
                        f"undefined in {groups_exp}"
                    )
                add_groups = groups_exp[tunable_group_name]
                # A mapping value may be a single group name or a list of names.
                res += [add_groups] if isinstance(add_groups, str) else add_groups
            else:
                res.append(grp)
        return res

    @staticmethod
    def _expand_vars(
        params: Dict[str, TunableValue],
        global_config: Dict[str, TunableValue],
    ) -> dict:
        """Expand `$var` into actual values of the variables, using `global_config`
        as an extra lookup source."""
        return DictTemplater(params).expand_vars(extra_source_dict=global_config)

    @property
    def _config_loader_service(self) -> "SupportsConfigLoading":
        """The config loader service exposed by the attached service (must be set)."""
        assert self._service is not None
        return self._service.config_loader_service

    def __enter__(self) -> "Environment":
        """Enter the environment's benchmarking context (and the service's, if any)."""
        _LOG.debug("Environment START :: %s", self)
        assert not self._in_context
        if self._service:
            self._service_context = self._service.__enter__()
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit the context of the benchmarking environment.

        Exits the service context (if one was entered) and always resets the
        context state. An exception raised while exiting the service context is
        logged and then re-raised. Never suppresses the in-flight exception.
        """
        ex_throw = None
        if ex_val is None:
            _LOG.debug("Environment END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self._in_context
        if self._service_context:
            try:
                self._service_context.__exit__(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting Service context '%s': %s", self._service, ex)
                ex_throw = ex
            finally:
                self._service_context = None
        self._in_context = False
        if ex_throw:
            raise ex_throw
        return False  # Do not suppress exceptions

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return f"{self.__class__.__name__} :: '{self.name}'"

    def pprint(self, indent: int = 4, level: int = 0) -> str:
        """
        Pretty-print the environment configuration. For composite environments, print
        all children environments as well.

        Parameters
        ----------
        indent : int
            Number of spaces to indent the output. Default is 4.
        level : int
            Current level of indentation. Default is 0.

        Returns
        -------
        pretty : str
            Pretty-printed environment configuration.
            Default output is the same as `__repr__`.
        """
        return f'{" " * indent * level}{repr(self)}'

    def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]:
        """
        Plug tunable values into the base config. If the tunable group is unknown,
        ignore it (it might belong to another environment). This method should never
        mutate the original config or the tunables.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of groups of tunable parameters
            along with the parameters' values.

        Returns
        -------
        params : Dict[str, Union[int, float, str]]
            Free-format dictionary that contains the new environment configuration.
        """
        # Only the groups this environment owns are pulled in; the values are
        # written into a copy of `const_args` so the original stays untouched.
        return tunables.get_param_values(
            group_names=list(self._tunable_params.get_covariant_group_names()),
            into_params=self._const_args.copy(),
        )

    @property
    def tunable_params(self) -> TunableGroups:
        """
        Get the configuration space of the given environment.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant groups of tunable parameters.
        """
        return self._tunable_params

    @property
    def parameters(self) -> Dict[str, TunableValue]:
        """
        Key/value pairs of all environment parameters (i.e., `const_args` and
        `tunable_params`). Note that before `.setup()` is called, all tunables will be
        set to None.

        Returns
        -------
        parameters : Dict[str, TunableValue]
            Key/value pairs of all environment parameters
            (i.e., `const_args` and `tunable_params`).
        """
        return self._params

    def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
        """
        Set up a new benchmark environment, if necessary. This method must be
        idempotent, i.e., calling it several times in a row should be equivalent to a
        single call.

        Must be called inside the environment's context (see `__enter__`).

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable parameters along with their values.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if operation is successful, false otherwise.
            The base implementation always returns True.
        """
        _LOG.info("Setup %s :: %s", self, tunables)
        assert isinstance(tunables, TunableGroups)

        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context

        # Assign new values to the environment's tunable parameters:
        groups = list(self._tunable_params.get_covariant_group_names())
        self._tunable_params.assign(tunables.get_param_values(groups))

        # Write to the log whether the environment needs to be reset.
        # (Derived classes still have to check `self._tunable_params.is_updated()`).
        is_updated = self._tunable_params.is_updated()
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug(
                "Env '%s': Tunable groups reset = %s :: %s",
                self,
                is_updated,
                {
                    name: self._tunable_params.is_updated([name])
                    for name in self._tunable_params.get_covariant_group_names()
                },
            )
        else:
            _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated)

        # Combine tunables, const_args, and global config into `self._params`:
        self._params = self._combine_tunables(tunables)
        merge_parameters(dest=self._params, source=global_config)

        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Combined parameters:\n%s", json.dumps(self._params, indent=2))

        return True

    def teardown(self) -> None:
        """
        Tear down the benchmark environment.

        This method must be idempotent, i.e., calling it several times in a row should
        be equivalent to a single call. The base implementation only marks the
        environment as not ready.
        """
        _LOG.info("Teardown %s", self)
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        self._is_ready = False

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Execute the run script for this environment.

        For instance, this may start a new experiment, download results, reconfigure
        the environment, etc. Details are configurable via the environment config.
        The base implementation just reports `status()` with no output.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        (status, timestamp, _) = self.status()
        return (status, timestamp, None)

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Check the status of the benchmark environment.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is UTC time stamp of the status; it's current time by default.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
            The base implementation reports READY if the environment is marked
            ready, PENDING otherwise, with empty telemetry.
        """
        # Make sure we create a context before invoking setup/run/status/teardown
        assert self._in_context
        timestamp = datetime.now(UTC)
        if self._is_ready:
            return (Status.READY, timestamp, [])
        _LOG.warning("Environment not ready: %s", self)
        return (Status.PENDING, timestamp, [])