Coverage for mlos_bench/mlos_bench/environments/composite_env.py: 88%
95 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Composite benchmark environment."""
7import logging
8from datetime import datetime
9from types import TracebackType
10from typing import Any, Dict, List, Literal, Optional, Tuple, Type
12from mlos_bench.environments.base_environment import Environment
13from mlos_bench.environments.status import Status
14from mlos_bench.services.base_service import Service
15from mlos_bench.tunables.tunable import TunableValue
16from mlos_bench.tunables.tunable_groups import TunableGroups
18_LOG = logging.getLogger(__name__)
class CompositeEnv(Environment):
    """Composite benchmark environment."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        name: str,
        config: dict,
        global_config: Optional[dict] = None,
        tunables: Optional[TunableGroups] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new environment with a given config.

        Parameters
        ----------
        name: str
            Human-readable name of the environment.
        config : dict
            Free-format dictionary that contains the environment
            configuration. Must declare at least one child environment via
            the "children" and/or "include_children" sections.
        global_config : dict
            Free-format dictionary of global parameters (e.g., security credentials)
            to be mixed in into the "const_args" section of the local config.
        tunables : TunableGroups
            A collection of groups of tunable parameters for *all* environments.
        service: Service
            An optional service object (e.g., providing methods to
            deploy or reboot a VM, etc.).

        Raises
        ------
        ValueError
            If neither "children" nor "include_children" yields any child environment.
        """
        super().__init__(
            name=name,
            config=config,
            global_config=global_config,
            tunables=tunables,
            service=service,
        )

        # By default, the Environment includes only the tunables explicitly specified
        # in the "tunable_params" section of the config. `CompositeEnv`, however, must
        # retain all tunables from its children environments plus the ones that come
        # from the "include_tunables".
        tunables = tunables.copy() if tunables else TunableGroups()

        _LOG.debug("Build composite environment '%s' START: %s", self, tunables)
        self._children: List[Environment] = []
        self._child_contexts: List[Environment] = []

        # To support trees of composite environments (e.g. for multiple VM experiments),
        # each CompositeEnv gets a copy of the original global config and adjusts it with
        # the `const_args` specific to it.
        global_config = (global_config or {}).copy()
        for key, val in self._const_args.items():
            # Values already present in the global config take precedence.
            global_config.setdefault(key, val)

        for child_config_file in config.get("include_children", []):
            for env in self._config_loader_service.load_environment_list(
                child_config_file,
                tunables,
                global_config,
                self._const_args,
                self._service,
            ):
                self._add_child(env, tunables)

        for child_config in config.get("children", []):
            env = self._config_loader_service.build_environment(
                child_config,
                tunables,
                global_config,
                self._const_args,
                self._service,
            )
            self._add_child(env, tunables)

        _LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params)

        if not self._children:
            raise ValueError("At least one child environment must be present")

    def __enter__(self) -> Environment:
        """
        Enter the context of every child environment, then of this environment.

        If entering any child (or this environment itself) fails, the children
        that have already been entered are exited in reverse order before the
        original exception is re-raised, so that no child context is leaked.
        """
        entered: List[Environment] = []
        contexts: List[Environment] = []
        try:
            for env in self._children:
                contexts.append(env.__enter__())
                # Only record the child once its __enter__ has succeeded.
                entered.append(env)
            self._child_contexts = contexts
            return super().__enter__()
        except BaseException as ex:
            # `with` does not call __exit__ when __enter__ raises,
            # so we have to unwind the already entered children ourselves.
            for env in reversed(entered):
                try:
                    env.__exit__(type(ex), ex, ex.__traceback__)
                # pylint: disable=broad-exception-caught
                except Exception as exit_ex:
                    _LOG.error(
                        "Exception while exiting child environment '%s': %s", env, exit_ex
                    )
            self._child_contexts = []
            raise

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """
        Exit the context of every child environment (in reverse order), then of
        this environment.

        Exceptions raised by the children are logged and do not prevent the
        remaining children from being exited; the most recently caught one is
        re-raised at the end.
        """
        ex_throw = None
        for env in reversed(self._children):
            try:
                env.__exit__(ex_type, ex_val, ex_tb)
            # pylint: disable=broad-exception-caught
            except Exception as ex:
                _LOG.error("Exception while exiting child environment '%s': %s", env, ex)
                ex_throw = ex
        self._child_contexts = []
        super().__exit__(ex_type, ex_val, ex_tb)
        if ex_throw:
            raise ex_throw
        return False

    @property
    def children(self) -> List[Environment]:
        """Return the list of child environments."""
        return self._children

    def pprint(self, indent: int = 4, level: int = 0) -> str:
        """
        Pretty-print the environment and its children.

        Parameters
        ----------
        indent : int
            Number of spaces to indent the output at each level. Default is 4.
        level : int
            Current level of indentation. Default is 0.

        Returns
        -------
        pretty : str
            Pretty-printed environment configuration.
        """
        return (
            super().pprint(indent, level)
            + "\n"
            + "\n".join(child.pprint(indent, level + 1) for child in self._children)
        )

    def _add_child(self, env: Environment, tunables: TunableGroups) -> None:
        """
        Add a new child environment to the composite environment.

        The child's tunable parameters are merged both into this environment's
        own tunables and into the `tunables` collection passed in by the caller.

        This method is called from the constructor only.
        """
        _LOG.debug("Merge tunables: '%s' <- '%s' :: %s", self, env, env.tunable_params)
        self._children.append(env)
        self._tunable_params.merge(env.tunable_params)
        tunables.merge(env.tunable_params)

    def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool:
        """
        Set up the children environments.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of tunable parameters along with their values.
        global_config : dict
            Free-format dictionary of global parameters of the environment
            that are not used in the optimization process.

        Returns
        -------
        is_success : bool
            True if all children setup() operations are successful,
            false otherwise.
        """
        assert self._in_context
        # Short-circuit: children are set up in order and only while
        # everything before them has succeeded.
        self._is_ready = super().setup(tunables, global_config) and all(
            env_context.setup(tunables, global_config) for env_context in self._child_contexts
        )
        return self._is_ready

    def teardown(self) -> None:
        """
        Tear down the children environments.

        This method is idempotent, i.e., calling it several times is equivalent to a
        single call. The environments are being torn down in the reverse order.
        """
        assert self._in_context
        for env_context in reversed(self._child_contexts):
            env_context.teardown()
        super().teardown()

    def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]:
        """
        Submit a new experiment to the environment. Return the result of the *last*
        child environment if successful, or the status of the last failed environment
        otherwise.

        Returns
        -------
        (status, timestamp, output) : (Status, datetime.datetime, dict)
            3-tuple of (Status, timestamp, output) values, where `output` is a dict
            with the results or None if the status is not COMPLETED.
            If run script is a benchmark, then the score is usually expected to
            be in the `score` field.
        """
        _LOG.info("Run: %s", self._children)
        (status, timestamp, metrics) = super().run()
        if not status.is_ready():
            return (status, timestamp, metrics)

        # Metrics of later children override same-named metrics of earlier ones.
        joint_metrics: Dict[str, TunableValue] = {}
        for env_context in self._child_contexts:
            _LOG.debug("Child env. run: %s", env_context)
            (status, timestamp, metrics) = env_context.run()
            _LOG.debug("Child env. run results: %s :: %s %s", env_context, status, metrics)
            if not status.is_good():
                # Stop at the first failure; the remaining children are not run.
                _LOG.info("Run failed: %s :: %s", self, status)
                return (status, timestamp, None)
            joint_metrics.update(metrics or {})

        _LOG.info("Run completed: %s :: %s %s", self, status, joint_metrics)
        # Return the status and the timestamp of the last child environment.
        return (status, timestamp, joint_metrics)

    def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]:
        """
        Check the status of the benchmark environment.

        Returns
        -------
        (benchmark_status, timestamp, telemetry) : (Status, datetime.datetime, list)
            3-tuple of (benchmark status, timestamp, telemetry) values.
            `timestamp` is UTC time stamp of the status; it's current time by default.
            `telemetry` is a list (maybe empty) of (timestamp, metric, value) triplets.
        """
        (status, timestamp, telemetry) = super().status()
        if not status.is_ready():
            return (status, timestamp, telemetry)

        joint_telemetry: List[Tuple[datetime, str, Any]] = []
        final_status: Optional[Status] = None
        for env_context in self._child_contexts:
            (status, timestamp, telemetry) = env_context.status()
            _LOG.debug("Child env. status: %s :: %s", env_context, status)
            joint_telemetry.extend(telemetry)
            # Remember the first non-good status, but keep polling the
            # remaining children to collect all of their telemetry.
            if final_status is None and not status.is_good():
                final_status = status

        # No child failed: report the status of the last one polled.
        # Explicit `is None` check so the result does not depend on
        # the truthiness of the `Status` values.
        if final_status is None:
            final_status = status
        _LOG.info("Final status: %s :: %s", self, final_status)
        # Return the status of the first failed child environment (or of the last
        # child if none has failed) and the timestamp of the last child environment.
        return (final_status, timestamp, joint_telemetry)