Coverage for mlos_bench/mlos_bench/environments/composite_env.py: 89%
96 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Composite benchmark environment."""
7import logging
8from datetime import datetime
9from types import TracebackType
10from typing import Any, Dict, List, Optional, Tuple, Type
12from typing_extensions import Literal
14from mlos_bench.environments.base_environment import Environment
15from mlos_bench.environments.status import Status
16from mlos_bench.services.base_service import Service
17from mlos_bench.tunables.tunable import TunableValue
18from mlos_bench.tunables.tunable_groups import TunableGroups
# Module-level logger, named after this module (``__name__``).
_LOG = logging.getLogger(name=__name__)
class CompositeEnv(Environment):
    """
    Composite benchmark environment.

    Holds an ordered list of child environments. They are set up and run in
    order, torn down in reverse order, and their tunables are merged into this
    environment's tunables.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the environment configuration.
            Children are built from its "include_children" section (each entry
            is handed to the config loader's ``load_environment_list``) and its
            "children" section (each entry is an inline child config handed to
            ``build_environment``). Together they must yield at least one child.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for *all* environments.
            It is copied, not modified in place.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).

        Raises
        ------
        ValueError
            If no child environment results from the config.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        # By default, the Environment includes only the tunables explicitly specified
        # in the "tunable_params" section of the config. `CompositeEnv`, however, must
        # retain all tunables from its children environments plus the ones that come
        # from the "include_tunables".
        tunables = tunables.copy() if tunables else TunableGroups()

        _LOG.debug("Build composite environment '%s' START: %s", self, tunables)
        self._children: List[Environment] = []
        self._child_contexts: List[Environment] = []

        # To support trees of composite environments (e.g. for multiple VM experiments),
        # each CompositeEnv gets a copy of the original global config and adjusts it with
        # the `const_args` specific to it. Existing global keys take precedence.
        global_config = (global_config or {}).copy()
        for key, val in self._const_args.items():
            global_config.setdefault(key, val)

        for child_config_file in config.get("include_children", []):
            for env in self._config_loader_service.load_environment_list(
                child_config_file,
                tunables,
                global_config,
                self._const_args,
                self._service,
            ):
                self._add_child(env, tunables)

        for child_config in config.get("children", []):
            env = self._config_loader_service.build_environment(
                child_config,
                tunables,
                global_config,
                self._const_args,
                self._service,
            )
            self._add_child(env, tunables)

        _LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params)

        if not self._children:
            raise ValueError("At least one child environment must be present")

    def __enter__(self) -> Environment:
        """
        Enter the context of every child environment (in order), then this one.

        If entering a child, or this environment itself, raises, the children that
        were already entered are exited again (in reverse order) before the
        original exception propagates. Errors raised while exiting them are
        logged and do not replace the original exception.
        """
        contexts: List[Environment] = []
        try:
            for env in self._children:
                contexts.append(env.__enter__())
            self._child_contexts = contexts
            return super().__enter__()
        except BaseException as ex:
            # A failed __enter__ means the `with` statement will never call our
            # __exit__, so the already-entered children must be rolled back here.
            # `contexts` has exactly one entry per successfully entered child.
            self._child_contexts = []
            self._exit_children(
                self._children[: len(contexts)],
                type(ex),
                ex,
                ex.__traceback__,
            )
            raise

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit all child contexts in reverse order, then this environment's context.

        Every child is exited even if an earlier one raises. After this
        environment's own context is exited, the exception raised by the child
        exited last (the earliest child), if any, is re-raised.
        Never suppresses the exception being propagated through the `with` block.
        """
        ex_throw = self._exit_children(self._children, ex_type, ex_val, ex_tb)
        self._child_contexts = []
        super().__exit__(ex_type, ex_val, ex_tb)
        if ex_throw is not None:
            raise ex_throw
        return False

    @staticmethod
    def _exit_children(
        children: List[Environment],
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Optional[Exception]:
        """
        Call ``__exit__`` on `children` in reverse order, continuing past failures.

        Returns
        -------
        error : Optional[Exception]
            The exception raised by the child exited last (i.e., the earliest
            child in `children` that failed), or None if all exited cleanly.
            Every failure is also logged.
        """
        ex_throw: Optional[Exception] = None
        for env in reversed(children):
            try:
                env.__exit__(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting child environment '%s': %s", env, ex)
                ex_throw = ex
        return ex_throw

    @property
    def children(self) -> List[Environment]:
        """Return the list of child environments."""
        return self._children

    def pprint(self, indent: int = 4, level: int = 0) -> str:
        """
        Pretty-print the environment and its children.

        Parameters
        ----------
        indent : int
            Number of spaces to indent the output at each level. Default is 4.
        level : int
            Current level of indentation. Default is 0.

        Returns
        -------
        pretty : str
            Pretty-printed environment configuration, with each child printed
            one level deeper on the following lines.
        """
        return (
            super().pprint(indent, level)
            + "\n"
            + "\n".join(child.pprint(indent, level + 1) for child in self._children)
        )

    def _add_child(self, env: Environment, tunables: TunableGroups) -> None:
        """
        Add a new child environment to the composite environment.

        The child's tunable parameters are merged into both this environment's
        tunables and the `tunables` collection that is passed on to later children.
        This method is called from the constructor only.
        """
        _LOG.debug("Merge tunables: '%s' <- '%s' :: %s", self, env, env.tunable_params)
        self._children.append(env)
        self._tunable_params.merge(env.tunable_params)
        tunables.merge(env.tunable_params)

    def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
        """
        Set up the children environments.

        This environment is set up first. The children are then set up in order,
        stopping at the first child whose setup() returns False.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable parameters along with their values.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if all children setup() operations are successful,
            false otherwise.
        """
        assert self._in_context
        self._is_ready = super().setup(tunables, global_config) and all(
            env_context.setup(tunables, global_config) for env_context in self._child_contexts
        )
        return self._is_ready

    def teardown(self) -> None:
        """
        Tear down the children environments.

        This method is idempotent, i.e., calling it several times is equivalent to a
        single call. The environments are being torn down in the reverse order,
        and this environment itself is torn down last.
        """
        assert self._in_context
        for env_context in reversed(self._child_contexts):
            env_context.teardown()
        super().teardown()

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Submit a new experiment to the environment.

        The base-class run() is called first. If its status is not ready
        (``Status.is_ready()``), its result is returned unchanged. Otherwise the
        children are run in order:

        * if a child's status is not good (``Status.is_good()``), that child's
          status and timestamp are returned with ``None`` output, and the
          remaining children are not run;
        * otherwise the children's output dicts are merged (later children
          override keys of earlier ones) and returned together with the status
          and timestamp of the last child.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        _LOG.info("Run: %s", self._children)
        (status, timestamp, metrics) = super().run()
        if not status.is_ready():
            return (status, timestamp, metrics)

        joint_metrics: Dict[str, TunableValue] = {}
        for env_context in self._child_contexts:
            _LOG.debug("Child env. run: %s", env_context)
            (status, timestamp, metrics) = env_context.run()
            _LOG.debug("Child env. run results: %s :: %s %s", env_context, status, metrics)
            if not status.is_good():
                _LOG.info("Run failed: %s :: %s", self, status)
                return (status, timestamp, None)
            joint_metrics.update(metrics or {})

        _LOG.info("Run completed: %s :: %s %s", self, status, joint_metrics)
        # Return the status and the timestamp of the last child environment.
        return (status, timestamp, joint_metrics)

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Check the status of the benchmark environment.

        The base-class status() is queried first. If its status is not ready, its
        result is returned unchanged. Otherwise every child is queried in order,
        and the children's telemetry lists are concatenated.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `benchmark_status` is the status of the first child whose status is
            not good, or that of the last child if all are good.
            `timestamp` is the UTC time stamp reported by the last child.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
        """
        (status, timestamp, telemetry) = super().status()
        if not status.is_ready():
            return (status, timestamp, telemetry)

        joint_telemetry: List[Tuple[datetime, str, Any]] = []
        final_status: Optional[Status] = None
        for env_context in self._child_contexts:
            (status, timestamp, telemetry) = env_context.status()
            _LOG.debug("Child env. status: %s :: %s", env_context, status)
            joint_telemetry.extend(telemetry)
            if not status.is_good() and final_status is None:
                final_status = status

        # Explicit None check: do not rely on the truthiness of Status members.
        if final_status is None:
            final_status = status
        _LOG.info("Final status: %s :: %s", self, final_status)
        # Return the status of the first failed child environment (or of the last
        # one if none failed) and, in either case, the timestamp of the last child.
        return (final_status, timestamp, joint_telemetry)