Coverage for mlos_bench/mlos_bench/tunables/tunable_groups.py: 99%
88 statements
coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""TunableGroups definition."""
6import copy
7from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple, Union
9from mlos_bench.config.schemas import ConfigSchema
10from mlos_bench.tunables.covariant_group import CovariantTunableGroup
11from mlos_bench.tunables.tunable import Tunable, TunableValue
class TunableGroups:
    """A collection of covariant groups of tunable parameters."""

    def __init__(self, config: Optional[dict] = None) -> None:
        """
        Create a new group of tunable parameters.

        Parameters
        ----------
        config : dict
            Python dict of serialized representation of the covariant tunable groups,
            keyed by covariant group name. It is validated against
            ``ConfigSchema.TUNABLE_PARAMS`` before use. An empty collection is
            created if omitted.
        """
        if config is None:
            config = {}
        ConfigSchema.TUNABLE_PARAMS.validate(config)
        # Index (Tunable id -> CovariantTunableGroup)
        self._index: Dict[str, CovariantTunableGroup] = {}
        # Covariant group name -> CovariantTunableGroup
        self._tunable_groups: Dict[str, CovariantTunableGroup] = {}
        for name, group_config in config.items():
            self._add_group(CovariantTunableGroup(name, group_config))

    @staticmethod
    def _get_tunable_name(tunable: Union[str, Tunable]) -> str:
        """Return the tunable's name, given either the Tunable object or its name."""
        return tunable.name if isinstance(tunable, Tunable) else tunable

    def __bool__(self) -> bool:
        """True if the collection contains at least one tunable parameter."""
        return bool(self._index)

    def __len__(self) -> int:
        """Total number of tunable parameters across all covariant groups."""
        return len(self._index)

    # NOTE: defining __eq__ without __hash__ makes instances unhashable,
    # which is appropriate for this mutable container.
    def __eq__(self, other: object) -> bool:
        """
        Check if two TunableGroups are equal.

        Two collections are equal when their covariant groups (keyed by group
        name) compare equal.

        Parameters
        ----------
        other : TunableGroups
            A tunable groups object to compare to.

        Returns
        -------
        is_equal : bool
            True if two TunableGroups are equal;
            False if `other` is not a TunableGroups instance.
        """
        if not isinstance(other, TunableGroups):
            return False
        return bool(self._tunable_groups == other._tunable_groups)

    def copy(self) -> "TunableGroups":
        """
        Deep copy of the TunableGroups object.

        Returns
        -------
        tunables : TunableGroups
            A new instance of the TunableGroups object
            that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    def _add_group(self, group: CovariantTunableGroup) -> None:
        """
        Add a CovariantTunableGroup to the current collection.

        Note: non-overlapping groups are expected to be added to the collection.
        All checks are performed before any state is modified, so a rejected
        group leaves the collection unchanged.

        Parameters
        ----------
        group : CovariantTunableGroup
            The group to add. It is stored by reference, not copied.

        Raises
        ------
        AssertionError
            If a group with the same name is already in the collection.
        ValueError
            If any tunable of the group is already present in the collection.
        """
        assert (
            group.name not in self._tunable_groups
        ), f"Duplicate covariant tunable group name {group.name} in {self}"
        # Materialize once: we iterate twice (validate, then index).
        tunables = list(group.get_tunables())
        # Validate everything first so a failure cannot leave the group
        # registered in _tunable_groups with only part of its tunables indexed.
        for tunable in tunables:
            if tunable.name in self._index:
                raise ValueError(
                    f"Duplicate Tunable {tunable.name} from group {group.name} in {self}"
                )
        self._tunable_groups[group.name] = group
        for tunable in tunables:
            self._index[tunable.name] = group

    def merge(self, tunables: "TunableGroups") -> "TunableGroups":
        """
        Merge the two collections of covariant tunable groups.

        Unlike the dict `update` method, this method does not modify the
        original when overlapping keys are found.
        It is expected to be used to merge the tunable groups referenced by a
        standalone Environment config into a parent CompositeEnvironment,
        for instance.
        This allows self contained, potentially overlapping, but also
        overridable configs to be composed together.

        Groups that are new to this collection are added by reference (not
        copied). Groups already present (by name) are kept as-is; the incoming
        group must match the existing one according to ``equals_defaults``.

        Parameters
        ----------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.

        Raises
        ------
        ValueError
            If a same-named group does not match the existing one, or if a new
            group contains a tunable that is already present in this collection.
            Groups processed before the error remain merged.
        """
        # pylint: disable=protected-access
        # Check that covariant groups are unique, else throw an error.
        for group in tunables._tunable_groups.values():
            if group.name not in self._tunable_groups:
                self._add_group(group)
            else:
                # Check that there's no overlap in the tunables.
                # But allow for differing current values.
                if not self._tunable_groups[group.name].equals_defaults(group):
                    raise ValueError(
                        f"Overlapping covariant tunable group name {group.name} "
                        f"in {self._tunable_groups[group.name]} and {tunables}"
                    )
        return self

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the TunableGroups (mostly for logging).

        Groups are listed by descending cost, then by name; tunables within a
        group are listed in sorted order.

        Returns
        -------
        string : str
            A human-readable version of the TunableGroups.
        """
        # pylint: disable=protected-access
        return (
            "{ "
            + ", ".join(
                f"{group.name}::{tunable}"
                for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name))
                for tunable in sorted(group._tunables.values())
            )
            + " }"
        )

    def __contains__(self, tunable: Union[str, Tunable]) -> bool:
        """Checks if the given name/tunable is in this tunable group."""
        return self._get_tunable_name(tunable) in self._index

    def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue:
        """
        Get the current value of a single tunable parameter.

        Raises KeyError if the tunable is not in the collection.
        """
        name = self._get_tunable_name(tunable)
        return self._index[name][name]

    def __setitem__(
        self,
        tunable: Union[str, Tunable],
        tunable_value: Union[TunableValue, Tunable],
    ) -> TunableValue:
        """
        Update the current value of a single tunable parameter.

        `tunable_value` may be a raw value or a Tunable, in which case its
        current `.value` is used. Raises KeyError if the tunable is not in the
        collection.
        """
        name = self._get_tunable_name(tunable)
        value: TunableValue = (
            tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value
        )
        # Use double index to make sure we set the is_updated flag of the group
        self._index[name][name] = value
        return self._index[name][name]

    def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, None]:
        """
        An iterator over all tunables in the group.

        Returns
        -------
        [(tunable, group), ...] : iter(Tunable, CovariantTunableGroup)
            An iterator over all tunables in all groups. Each element is a 2-tuple
            of an instance of the Tunable parameter and covariant group it belongs to.
        """
        return ((group.get_tunable(name), group) for (name, group) in self._index.items())

    def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]:
        """
        Access the entire Tunable (not just its value) and its covariant group. Throw
        KeyError if the tunable is not found.

        Parameters
        ----------
        tunable : Union[str, Tunable]
            Name of the tunable parameter (or the Tunable itself).

        Returns
        -------
        (tunable, group) : (Tunable, CovariantTunableGroup)
            A 2-tuple of an instance of the Tunable parameter and covariant group it belongs to.
        """
        name = self._get_tunable_name(tunable)
        group = self._index[name]
        return (group.get_tunable(name), group)

    def get_covariant_group_names(self) -> Iterable[str]:
        """
        Get the names of all covariance groups in the collection.

        Returns
        -------
        group_names : [str]
            IDs of the covariant tunable groups, as a live view of the
            underlying dict's keys.
        """
        return self._tunable_groups.keys()

    def subgroup(self, group_names: Iterable[str]) -> "TunableGroups":
        """
        Select the covariance groups from the current set and create a new TunableGroups
        object that consists of those covariance groups.

        Note: The new TunableGroup will include *references* (not copies) to
        original ones, so each will get updated together.
        This is often desirable to support the use case of multiple related
        Environments (e.g. Local vs Remote) using the same set of tunables
        within a CompositeEnvironment.

        Parameters
        ----------
        group_names : list of str
            IDs of the covariant tunable groups.

        Returns
        -------
        tunables : TunableGroups
            A collection of covariant tunable groups.

        Raises
        ------
        KeyError
            If any of the group names is not in this collection.
        """
        # pylint: disable=protected-access
        tunables = TunableGroups()
        for name in group_names:
            if name not in self._tunable_groups:
                raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}")
            tunables._add_group(self._tunable_groups[name])
        return tunables

    def get_param_values(
        self,
        group_names: Optional[Iterable[str]] = None,
        into_params: Optional[Dict[str, TunableValue]] = None,
    ) -> Dict[str, TunableValue]:
        """
        Get the current values of the tunables that belong to the specified covariance
        groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the covariant tunable groups.
            Select parameters from all groups if omitted (None).
        into_params : dict
            An optional dict to copy the parameters and their values into.
            It is updated in place and returned.

        Returns
        -------
        into_params : dict
            Flat dict of all parameters and their values from given covariance groups.

        Raises
        ------
        KeyError
            If any of the group names is not in this collection.
        """
        if group_names is None:
            group_names = self.get_covariant_group_names()
        if into_params is None:
            into_params = {}
        for name in group_names:
            into_params.update(self._tunable_groups[name].get_tunable_values_dict())
        return into_params

    def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool:
        """
        Check if any of the given covariant tunable groups has been updated.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Check all groups if omitted
            (None) or empty.

        Returns
        -------
        is_updated : bool
            True if any of the specified tunable groups has been updated, False otherwise.
        """
        return any(
            self._tunable_groups[name].is_updated()
            for name in (group_names or self.get_covariant_group_names())
        )

    def is_defaults(self) -> bool:
        """
        Checks whether the currently assigned values of all tunables are at their
        defaults.

        Returns
        -------
        bool
            True if every covariant group reports being at its defaults
            (vacuously True for an empty collection).
        """
        return all(group.is_defaults() for group in self._tunable_groups.values())

    def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Restore all tunable parameters to their default values.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Restore all groups if omitted
            (None) or empty.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].restore_defaults()
        return self

    def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups":
        """
        Clear the update flag of given covariant groups.

        Parameters
        ----------
        group_names : list of str or None
            IDs of the (covariant) tunable groups. Reset all groups if omitted
            (None) or empty.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for name in group_names or self.get_covariant_group_names():
            self._tunable_groups[name].reset_is_updated()
        return self

    def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups":
        """
        In-place update the values of the tunables from the dictionary of (key, value)
        pairs.

        Values are applied one at a time in the mapping's iteration order; if a
        key is unknown, KeyError is raised and earlier assignments remain applied.

        Parameters
        ----------
        param_values : Mapping[str, TunableValue]
            Dictionary mapping Tunable parameter names to new values.

        Returns
        -------
        self : TunableGroups
            Self-reference for chaining.
        """
        for key, value in param_values.items():
            self[key] = value
        return self