Coverage for mlos_bench/mlos_bench/tunables/tunable_groups.py: 99%
93 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""TunableGroups definition."""
6import copy
7import logging
8from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple, Union
10from mlos_bench.config.schemas import ConfigSchema
11from mlos_bench.tunables.covariant_group import CovariantTunableGroup
12from mlos_bench.tunables.tunable import Tunable, TunableValue
# Module-level logger, named after this module via ``__name__``.
_LOG = logging.getLogger(name=__name__)
class TunableGroups:
    """A collection of :py:class:`.CovariantTunableGroup` s of :py:class:`.Tunable`
    parameters.
    """

    def __init__(self, config: Optional[dict] = None):
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        config : dict
            Python dict of serialized representation of the covariant tunable groups.
            Validated against ``ConfigSchema.TUNABLE_PARAMS``; ``None`` is treated as
            an empty config.

        See Also
        --------
        :py:mod:`mlos_bench.tunables` : for more information on tunable parameters and
            their configuration.
        """
        if config is None:
            config = {}
        ConfigSchema.TUNABLE_PARAMS.validate(config)
        # Index (Tunable id -> CovariantTunableGroup)
        self._index: Dict[str, CovariantTunableGroup] = {}
        # Covariant group name -> CovariantTunableGroup
        self._tunable_groups: Dict[str, CovariantTunableGroup] = {}
        for name, group_config in config.items():
            self._add_group(CovariantTunableGroup(name, group_config))

    def __bool__(self) -> bool:
        """True if the collection contains at least one tunable."""
        return bool(self._index)

    def __len__(self) -> int:
        """Number of tunables (not groups) in the collection."""
        return len(self._index)

    def __eq__(self, other: object) -> bool:
        """
        Check if two TunableGroups are equal.

        Parameters
        ----------
        other : TunableGroups
            A tunable groups object to compare to.

        Returns
        -------
        is_equal : bool
            True if two TunableGroups are equal, i.e., their covariant groups
            (keyed by group name) compare equal. False for non-TunableGroups objects.
        """
        if not isinstance(other, TunableGroups):
            return False
        return bool(self._tunable_groups == other._tunable_groups)

    def copy(self) -> "TunableGroups":
        """
        Deep copy of the TunableGroups object.

        Returns
        -------
        tunables : TunableGroups
            A new instance of the TunableGroups object
            that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    def _add_group(self, group: CovariantTunableGroup) -> None:
        """
        Add a CovariantTunableGroup to the current collection.

        Note: non-overlapping groups are expected to be added to the collection.
        The group is fully validated before any internal state is modified, so a
        rejected group leaves the collection unchanged.

        Parameters
        ----------
        group : CovariantTunableGroup
            The group to add. It is stored by reference, not copied.

        Raises
        ------
        AssertionError
            If a group with the same name is already in the collection.
        ValueError
            If any tunable of the group is already indexed (by this or another
            group), or occurs more than once within the group itself.
        """
        assert (
            group.name not in self._tunable_groups
        ), f"Duplicate covariant tunable group name {group.name} in {self}"
        # Collect the new index entries first and only commit them once every
        # tunable has passed the duplicate check, so a failure part-way through
        # cannot leave a half-registered group in `_tunable_groups` / `_index`.
        pending: Dict[str, CovariantTunableGroup] = {}
        for tunable in group.get_tunables():
            if tunable.name in self._index or tunable.name in pending:
                raise ValueError(
                    f"Duplicate Tunable {tunable.name} from group {group.name} in {self}"
                )
            pending[tunable.name] = group
        self._tunable_groups[group.name] = group
        self._index.update(pending)

    def merge(self, tunables: "TunableGroups") -> "TunableGroups":
        """
        Merge the two collections of covariant tunable groups.

        Unlike the dict `update` method, this method does not modify the
        original when overlapping keys are found.
        It is expected to be used to merge the tunable groups referenced by a
        standalone Environment config into a parent CompositeEnvironment,
        for instance.
        This allows self contained, potentially overlapping, but also
        overridable configs to be composed together.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        ValueError
            If a group with the same name exists in both collections but with
            different default definitions, or if a new group contains a tunable
            that is already present in this collection. Groups processed before
            the failing one remain merged.
        """
        # pylint: disable=protected-access
        # Check that covariant groups are unique, else throw an error.
        for group in tunables._tunable_groups.values():
            if group.name not in self._tunable_groups:
                self._add_group(group)
            else:
                # Check that there's no overlap in the tunables.
                # But allow for differing current values.
                if not self._tunable_groups[group.name].equals_defaults(group):
                    raise ValueError(
                        f"Overlapping covariant tunable group name {group.name} "
                        f"in {self._tunable_groups[group.name]} and {tunables}"
                    )
        return self

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the TunableGroups (mostly for logging).

        Groups are listed by descending cost, then by name; tunables within a
        group are listed in their sorted order.

        Returns
        -------
        string : str
            A human-readable version of the TunableGroups.
        """
        return (
            "{ "
            + ", ".join(
                f"{group.name}::{tunable}"
                for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name))
                for tunable in sorted(group._tunables.values())
            )
            + " }"
        )

    def __contains__(self, tunable: Union[str, Tunable]) -> bool:
        """Checks if the given name/tunable is in this tunable group."""
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return name in self._index

    def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
        """
        Get the current value of a single tunable parameter.

        Raises KeyError if the tunable is not in the collection.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        return self._index[name][name]

    def __setitem__(
        self,
        tunable: Union[str, Tunable],
        tunable_value: Union[TunableValue, Tunable],
    ) -> TunableValue:
        """
        Update the current value of a single tunable parameter.

        If `tunable_value` is a Tunable, its current `.value` is used.
        Raises KeyError if the tunable is not in the collection.
        """
        # Use double index to make sure we set the is_updated flag of the group
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        value: TunableValue = (
            tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        )
        self._index[name][name] = value
        return self._index[name][name]

    def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, None]:
        """
        An iterator over all tunables in the group.

        Returns
        -------
        [(tunable, group), ...] : Generator[Tuple[Tunable, CovariantTunableGroup], None, None]
            An iterator over all tunables in all groups. Each element is a 2-tuple
            of an instance of the Tunable parameter and covariant group it belongs to.
        """
        return ((group.get_tunable(name), group) for (name, group) in self._index.items())

    def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]:
        """
        Access the entire Tunable (not just its value) and its covariant group. Throw
        KeyError if the tunable is not found.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter.

        Returns
        -------
        (tunable, group) : (Tunable, CovariantTunableGroup)
            A 2-tuple of an instance of the Tunable parameter and covariant group it belongs to.
        """
        name: str = tunable.name if isinstance(tunable, Tunable) else tunable
        group = self._index[name]
        return (group.get_tunable(name), group)

    def get_covariant_group_names(self) -> Iterable[str]:
        """
        Get the names of all covariance groups in the collection.

        Returns
        -------
        group_names : [str]
            IDs of the covariant tunable groups. This is a live view of the
            internal dict keys, so it reflects groups added later.
        """
        return self._tunable_groups.keys()

    def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
        """
        Select the covariance groups from the current set and create a new TunableGroups
        object that consists of those covariance groups.

        Note: The new TunableGroup will include *references* (not copies) to
        original ones, so each will get updated together.
        This is often desirable to support the use case of multiple related
        Environments (e.g. Local vs Remote) using the same set of tunables
        within a CompositeEnvironment.

        Parameters
        ----------
        group_names : list of str
            IDs of the covariant tunable groups.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Raises
        ------
        KeyError
            If any of the group names is not in this collection.
        """
        # pylint: disable=protected-access
        tunables = TunableGroups()
        for name in group_names:
            if name not in self._tunable_groups:
                raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}")
            tunables._add_group(self._tunable_groups[name])
        return tunables

    def get_param_values(
        self,
        group_names: Optional[Iterable[str]] = None,
        into_params: Optional[Dict[str, TunableValue]] = None,
    ) -> Dict[str, TunableValue]:
        """
        Get the current values of the tunables that belong to the specified covariance
        groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the covariant tunable groups.
            Select parameters from all groups if omitted (None). An empty list
            selects no groups.
        into_params : dict
            An optional dict to copy the parameters and their values into.
            It is updated in place and returned.

        Returns
        -------
        into_params : dict
            Flat dict of all parameters and their values from given covariance groups.
        """
        if group_names is None:
            group_names = self.get_covariant_group_names()
        if into_params is None:
            into_params = {}
        for name in group_names:
            into_params.update(self._tunable_groups[name].get_tunable_values_dict())
        return into_params

    def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool:
        """
        Check if any of the given covariant tunable groups has been updated.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Check all groups if omitted.
            Note: any falsy value (None or an empty collection) selects all groups.

        Returns
        -------
        is_updated : bool
            True if any of the specified tunable groups has been updated, False otherwise.
        """
        return any(
            self._tunable_groups[name].is_updated()
            for name in (group_names or self.get_covariant_group_names())
        )

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their
        defaults.

        Returns
        -------
        bool
            True if every covariant group reports its tunables at defaults
            (trivially True for an empty collection).
        """
        return all(group.is_defaults() for group in self._tunable_groups.values())

    def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Restore all tunable parameters to their default values.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Restore all groups if omitted.
            Note: any falsy value (None or an empty collection) selects all groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].restore_defaults()
        return self

    def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Clear the update flag of given covariant groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Reset all groups if omitted.
            Note: any falsy value (None or an empty collection) selects all groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].reset_is_updated()
        return self

    def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
        """
        In-place update the values of the tunables from the dictionary of (key, value)
        pairs.

        Parameters
        ----------
        param_values : Mapping[str, TunableValue]
            Dictionary mapping Tunable parameter names to new values.

            As a special behavior when the mapping is empty the method will restore
            the default values rather than no-op.
            This allows an empty dictionary in json configs to be used to reset the
            tunables to defaults without having to copy the original values from the
            tunable_params definition.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        if not param_values:
            _LOG.info("Empty tunable values set provided. Resetting all tunables to defaults.")
            return self.restore_defaults()
        for key, value in param_values.items():
            self[key] = value
        return self