Coverage for mlos_bench/mlos_bench/tunables/tunable.py: 97%
299 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Tunable parameter definition."""
6import collections
7import copy
8import logging
9from typing import (
10 Any,
11 Dict,
12 Iterable,
13 List,
14 Literal,
15 Optional,
16 Sequence,
17 Tuple,
18 Type,
19 TypedDict,
20 Union,
21)
23import numpy as np
25from mlos_bench.util import nullable
# Module-level logger, named after this module.
_LOG = logging.getLogger(__name__)

# NOTE: `None` is an allowed value (used for categorical tunables).
TunableValue = Union[int, float, str, None]
"""A tunable parameter value type alias."""

TunableValueType = Union[Type[int], Type[float], Type[str]]
"""Tunable value type."""

# Runtime counterpart of `TunableValue`: includes NoneType so that
# `isinstance(None, TunableValueTypeTuple)` holds.
TunableValueTypeTuple = (int, float, str, type(None))
"""
Tunable value type tuple.

For checking with isinstance()
"""

TunableValueTypeName = Literal["int", "float", "categorical"]
"""The string name of a tunable value type."""

TunableValuesDict = Dict[str, TunableValue]
"""Tunable values dictionary type."""

DistributionName = Literal["uniform", "normal", "beta"]
"""Tunable value distribution type."""
class DistributionDict(TypedDict, total=False):
    """
    A typed dict for tunable parameters' distributions.

    Declared with ``total=False``, so type checkers treat every key as optional.
    """

    # Name of the distribution; one of the `DistributionName` literals.
    type: DistributionName
    # Distribution parameters keyed by parameter name; may be None.
    params: Optional[Dict[str, float]]
class TunableDict(TypedDict, total=False):
    """
    A typed dict for tunable parameters.

    Mostly used for mypy type checking.

    These are the types expected to be received from the json config.
    Declared with ``total=False``, so type checkers treat every key as optional.
    """

    # Type name of the tunable: "int", "float", or "categorical".
    type: TunableValueTypeName
    # Free-form description of the tunable.
    description: Optional[str]
    # Default value of the tunable.
    default: TunableValue
    # Allowed values of a categorical tunable (an entry may be None).
    values: Optional[List[Optional[str]]]
    # Two-element sequence of range bounds for a numerical tunable.
    range: Optional[Union[Sequence[int], Sequence[float]]]
    # Number of quantization bins for a numerical tunable.
    quantization_bins: Optional[int]
    # Whether a numerical tunable is on a log scale.
    log: Optional[bool]
    # Distribution settings for a numerical tunable.
    distribution: Optional[DistributionDict]
    # Special values of a numerical tunable.
    special: Optional[Union[List[int], List[float]]]
    # Weights of the categorical values.
    values_weights: Optional[List[float]]
    # Weights of the special values of a numerical tunable.
    special_weights: Optional[List[float]]
    # Weight of the range of a numerical tunable.
    range_weight: Optional[float]
    # Free-form metadata attached to the tunable.
    meta: Dict[str, Any]
class Tunable:  # pylint: disable=too-many-instance-attributes,too-many-public-methods
    """A tunable parameter definition and its current value."""

    # Maps tunable types to their corresponding Python types by name.
    _DTYPE: Dict[TunableValueTypeName, TunableValueType] = {
        "int": int,
        "float": float,
        "categorical": str,
    }

    def __init__(self, name: str, config: TunableDict):
        """
        Create an instance of a new tunable parameter.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameter.
        config : dict
            Python dict that represents a Tunable (e.g., deserialized from JSON)

        Raises
        ------
        ValueError
            If the name, the type, the default value, or any other part of the
            configuration fails the sanity checks.

        See Also
        --------
        :py:mod:`mlos_bench.tunables` : for more information on tunable parameters and
            their configuration.
        """
        if not isinstance(name, str) or "!" in name:  # TODO: Use a regex here and in JSON schema
            raise ValueError(f"Invalid name of the tunable: {name}")
        self._name = name
        self._type: TunableValueTypeName = config["type"]  # required
        if self._type not in self._DTYPE:
            raise ValueError(f"Invalid parameter type: {self._type}")
        self._description = config.get("description")
        self._default = config["default"]
        # Coerce the default to the Python type of the tunable; keep None as is.
        self._default = self.dtype(self._default) if self._default is not None else self._default
        self._values = config.get("values")
        if self._values:
            # Categories are always stored as strings (or None).
            self._values = [str(v) if v is not None else v for v in self._values]
        self._meta: Dict[str, Any] = config.get("meta", {})
        self._range: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None
        self._quantization_bins: Optional[int] = config.get("quantization_bins")
        self._log: Optional[bool] = config.get("log")
        self._distribution: Optional[DistributionName] = None
        self._distribution_params: Dict[str, float] = {}
        distr = config.get("distribution")
        if distr:
            self._distribution = distr["type"]  # required
            self._distribution_params = distr.get("params") or {}
        config_range = config.get("range")
        if config_range is not None:
            assert len(config_range) == 2, f"Invalid range: {config_range}"
            config_range = (config_range[0], config_range[1])
            self._range = config_range
        self._special: Union[List[int], List[float]] = config.get("special") or []
        # A single list holds the weights of either the categories or the special values.
        self._weights: List[float] = (
            config.get("values_weights") or config.get("special_weights") or []
        )
        self._range_weight: Optional[float] = config.get("range_weight")
        self._current_value = None
        self._sanity_check()
        # Go through the property setter to coerce and validate the initial value.
        self.value = self._default

    def _sanity_check(self) -> None:
        """
        Check if the status of the Tunable is valid, and throw ValueError if it is
        not.

        Raises
        ------
        ValueError
            If the type-specific checks fail or the default value is not valid.
        """
        if self.is_categorical:
            self._sanity_check_categorical()
        elif self.is_numerical:
            self._sanity_check_numerical()
        else:
            raise ValueError(f"Invalid parameter type for tunable {self}: {self._type}")
        if not self.is_valid(self.default):
            raise ValueError(f"Invalid default value for tunable {self}: {self.default}")

    def _sanity_check_categorical(self) -> None:
        """
        Check if the status of the categorical Tunable is valid, and throw ValueError
        if it is not.

        Raises
        ------
        ValueError
            If values are missing or not unique, if any numerical-only setting
            (range, special values, range weight, log, quantization, distribution)
            is present, or if the weights are inconsistent with the values.
        """
        # pylint: disable=too-complex
        assert self.is_categorical
        if not (self._values and isinstance(self._values, collections.abc.Iterable)):
            raise ValueError(f"Must specify values for the categorical type tunable {self}")
        if self._range is not None:
            raise ValueError(f"Range must be None for the categorical type tunable {self}")
        if len(set(self._values)) != len(self._values):
            raise ValueError(f"Values must be unique for the categorical type tunable {self}")
        if self._special:
            raise ValueError(f"Categorical tunable cannot have special values: {self}")
        if self._range_weight is not None:
            raise ValueError(f"Categorical tunable cannot have range_weight: {self}")
        if self._log is not None:
            raise ValueError(f"Categorical tunable cannot have log parameter: {self}")
        if self._quantization_bins is not None:
            raise ValueError(f"Categorical tunable cannot have quantization parameter: {self}")
        if self._distribution is not None:
            raise ValueError(f"Categorical parameters do not support `distribution`: {self}")
        if self._weights:
            if len(self._weights) != len(self._values):
                raise ValueError(f"Must specify weights for all values: {self}")
            if any(w < 0 for w in self._weights):
                raise ValueError(f"All weights must be non-negative: {self}")

    def _sanity_check_numerical(self) -> None:
        """
        Check if the status of the numerical Tunable is valid, and throw ValueError
        if it is not.

        Raises
        ------
        ValueError
            If categorical values are present, the range is missing or not strictly
            increasing, the quantization or distribution settings are invalid, or
            the weights are inconsistent with the special values.
        """
        # pylint: disable=too-complex,too-many-branches
        assert self.is_numerical
        if self._values is not None:
            raise ValueError(f"Values must be None for the numerical type tunable {self}")
        if not self._range or len(self._range) != 2 or self._range[0] >= self._range[1]:
            raise ValueError(f"Invalid range for tunable {self}: {self._range}")
        if self._quantization_bins is not None and self._quantization_bins <= 1:
            raise ValueError(f"Number of quantization bins is <= 1: {self}")
        if self._distribution is not None and self._distribution not in {
            "uniform",
            "normal",
            "beta",
        }:
            raise ValueError(f"Invalid distribution: {self}")
        if self._distribution_params and self._distribution is None:
            raise ValueError(f"Must specify the distribution: {self}")
        if self._weights:
            if self._range_weight is None:
                raise ValueError(f"Must specify weight for the range: {self}")
            if len(self._weights) != len(self._special):
                # This message must be an f-string so that the tunable is interpolated.
                raise ValueError(f"Must specify weights for all special values {self}")
            if any(w < 0 for w in self._weights + [self._range_weight]):
                raise ValueError(f"All weights must be non-negative: {self}")
        elif self._range_weight is not None:
            raise ValueError(f"Must specify both weights and range_weight or none: {self}")

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the Tunable (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the Tunable.
        """
        # TODO? Add weights, specials, quantization, distribution?
        if self.is_categorical:
            return (
                f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}"
            )
        return f"{self._name}[{self._type}]({self._range}:{self._default})={self._current_value}"

    def __eq__(self, other: object) -> bool:
        """
        Check if two Tunable objects are equal.

        Parameters
        ----------
        other : Tunable
            A tunable object to compare to.

        Returns
        -------
        is_equal : bool
            True if the Tunables correspond to the same parameter and have the same value and type.
            NOTE: ranges and special values are not currently considered in the comparison.
        """
        if not isinstance(other, Tunable):
            return False
        return bool(
            self._name == other._name
            and self._type == other._type
            and self._current_value == other._current_value
        )

    def __lt__(self, other: object) -> bool:  # pylint: disable=too-many-return-statements
        """
        Compare the two Tunable objects. We mostly need this to create a canonical list
        of tunable objects when hashing a TunableGroup.

        Tunables are ordered by name, then by type name, then by current value.
        For categorical tunables, a None value sorts before any other value.

        Parameters
        ----------
        other : Tunable
            A tunable object to compare to.

        Returns
        -------
        is_less : bool
            True if the current Tunable is less than the other one, False otherwise.
        """
        if not isinstance(other, Tunable):
            return False
        if self._name < other._name:
            return True
        if self._name == other._name and self._type < other._type:
            return True
        if self._name == other._name and self._type == other._type:
            if self.is_numerical:
                assert self._current_value is not None
                assert other._current_value is not None
                return bool(float(self._current_value) < float(other._current_value))
            # else: categorical
            if self._current_value is None:
                return True
            if other._current_value is None:
                return False
            return bool(str(self._current_value) < str(other._current_value))
        return False

    def copy(self) -> "Tunable":
        """
        Deep copy of the Tunable object.

        Returns
        -------
        tunable : Tunable
            A new Tunable object that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    @property
    def default(self) -> TunableValue:
        """Get the default value of the tunable."""
        return self._default

    def is_default(self) -> bool:
        """
        Check whether the currently assigned value of the tunable is at its default.

        Returns
        -------
        is_default : bool
            True if the current value equals the default value, False otherwise.
        """
        return self._default == self._current_value

    @property
    def value(self) -> TunableValue:
        """Get the current value of the tunable."""
        return self._current_value

    @value.setter
    def value(self, value: TunableValue) -> TunableValue:
        """
        Set the current value of the tunable.

        The value is coerced to the Python type of the tunable and validated
        before it is assigned.

        Raises
        ------
        ValueError
            If assigning a float to an int tunable would lose precision, or if the
            coerced value is not valid for this tunable.
        Exception
            Any exception raised by the type coercion is logged and re-raised.
        """
        # We need this coercion for the values produced by some optimizers
        # (e.g., scikit-optimize) and for data restored from certain storage
        # systems (where values can be strings).
        try:
            if self.is_categorical and value is None:
                coerced_value = None
            else:
                assert value is not None
                coerced_value = self.dtype(value)
        except Exception:
            _LOG.error(
                "Impossible conversion: %s %s <- %s %s",
                self._type,
                self._name,
                type(value),
                value,
            )
            raise

        # Refuse to silently truncate a float with a fractional part into an int.
        if self._type == "int" and isinstance(value, float) and value != coerced_value:
            _LOG.error(
                "Loss of precision: %s %s <- %s %s",
                self._type,
                self._name,
                type(value),
                value,
            )
            raise ValueError(f"Loss of precision: {self._name}={value}")

        if not self.is_valid(coerced_value):
            _LOG.error(
                "Invalid assignment: %s %s <- %s %s",
                self._type,
                self._name,
                type(value),
                value,
            )
            raise ValueError(f"Invalid value for the tunable: {self._name}={value}")

        self._current_value = coerced_value
        return self._current_value

    def update(self, value: TunableValue) -> bool:
        """
        Assign the value to the tunable. Return True if it is a new value, False
        otherwise.

        Parameters
        ----------
        value : Union[int, float, str]
            Value to assign.

        Returns
        -------
        is_updated : bool
            True if the new value is different from the previous one, False otherwise.
        """
        prev_value = self._current_value
        self.value = value
        return prev_value != self._current_value

    def is_valid(self, value: TunableValue) -> bool:
        """
        Check if the value can be assigned to the tunable.

        Parameters
        ----------
        value : Union[int, float, str]
            Value to validate.

        Returns
        -------
        is_valid : bool
            True if the value is valid, False otherwise.

        Raises
        ------
        ValueError
            If a non-numeric value is checked against a numerical tunable, or if the
            tunable has neither categorical values nor a numerical range.
        """
        if self.is_categorical and self._values:
            return value in self._values
        elif self.is_numerical and self._range:
            if isinstance(value, (int, float)):
                return self.in_range(value) or value in self._special
            else:
                raise ValueError(f"Invalid value type for tunable {self}: {value}={type(value)}")
        else:
            raise ValueError(f"Invalid parameter type: {self._type}")

    def in_range(self, value: Union[int, float, str, None]) -> bool:
        """
        Check if the value is within the range of the tunable.

        Do *NOT* check for special values. Return False if the tunable or value is
        categorical or None.
        """
        return (
            isinstance(value, (float, int))
            and self.is_numerical
            and self._range is not None
            and bool(self._range[0] <= value <= self._range[1])
        )

    @property
    def category(self) -> Optional[str]:
        """
        Get the current value of the tunable as a string.

        Raises
        ------
        ValueError
            If the tunable is not categorical.
        """
        if self.is_categorical:
            return nullable(str, self._current_value)
        else:
            raise ValueError("Cannot get categorical values for a numerical tunable.")

    @category.setter
    def category(self, new_value: Optional[str]) -> Optional[str]:
        """Set the current value of the tunable."""
        assert self.is_categorical
        assert isinstance(new_value, (str, type(None)))
        self.value = new_value
        return self.value

    @property
    def numerical_value(self) -> Union[int, float]:
        """
        Get the current value of the tunable as a number.

        Raises
        ------
        ValueError
            If the tunable is categorical.
        """
        assert self._current_value is not None
        if self._type == "int":
            return int(self._current_value)
        elif self._type == "float":
            return float(self._current_value)
        else:
            raise ValueError("Cannot get numerical value for a categorical tunable.")

    @numerical_value.setter
    def numerical_value(self, new_value: Union[int, float]) -> Union[int, float]:
        """Set the current numerical value of the tunable."""
        # We need this coercion for the values produced by some optimizers
        # (e.g., scikit-optimize) and for data restored from certain storage
        # systems (where values can be strings).
        assert self.is_numerical
        self.value = new_value
        return self.value

    @property
    def name(self) -> str:
        """Get the name / string ID of the tunable."""
        return self._name

    @property
    def special(self) -> Union[List[int], List[float]]:
        """
        Get the special values of the tunable. Return an empty list if there are none.

        Returns
        -------
        special : [int] | [float]
            A list of special values of the tunable. Can be empty.
        """
        return self._special

    @property
    def is_special(self) -> bool:
        """
        Check if the current value of the tunable is special.

        Returns
        -------
        is_special : bool
            True if the current value of the tunable is special, False otherwise.
        """
        return self.value in self._special

    @property
    def weights(self) -> Optional[List[float]]:
        """
        Get the weights of the categories or special values of the tunable. Return an
        empty list if there are none.

        Returns
        -------
        weights : [float]
            A list of weights. Can be empty.
        """
        return self._weights

    @property
    def range_weight(self) -> Optional[float]:
        """
        Get weight of the range of the numeric tunable.

        Asserts that the tunable is numerical and has both special values and
        weights.

        Returns
        -------
        weight : float
            Weight of the range.
        """
        assert self.is_numerical
        assert self._special
        assert self._weights
        return self._range_weight

    @property
    def type(self) -> TunableValueTypeName:
        """
        Get the data type of the tunable.

        Returns
        -------
        type : str
            Data type of the tunable - one of {'int', 'float', 'categorical'}.
        """
        return self._type

    @property
    def dtype(self) -> TunableValueType:
        """
        Get the actual Python data type of the tunable.

        This is useful for bulk conversions of the input data.

        Returns
        -------
        dtype : type
            Data type of the tunable - one of {int, float, str}.
        """
        return self._DTYPE[self._type]

    @property
    def is_categorical(self) -> bool:
        """
        Check if the tunable is categorical.

        Returns
        -------
        is_categorical : bool
            True if the tunable is categorical, False otherwise.
        """
        return self._type == "categorical"

    @property
    def is_numerical(self) -> bool:
        """
        Check if the tunable is an integer or float.

        Returns
        -------
        is_int : bool
            True if the tunable is an integer or float, False otherwise.
        """
        return self._type in {"int", "float"}

    @property
    def range(self) -> Union[Tuple[int, int], Tuple[float, float]]:
        """
        Get the range of the tunable.

        Asserts that the tunable is numerical and that its range is set.

        Returns
        -------
        range : Union[Tuple[int, int], Tuple[float, float]]
            A 2-tuple of numbers that represents the range of the tunable.
            Numbers can be int or float, depending on the type of the tunable.
        """
        assert self.is_numerical
        assert self._range is not None
        return self._range

    @property
    def span(self) -> Union[int, float]:
        """
        Gets the span of the range.

        Note: this does not take quantization into account.

        Returns
        -------
        Union[int, float]
            (max - min) for numerical tunables.
        """
        num_range = self.range
        return num_range[1] - num_range[0]

    @property
    def quantization_bins(self) -> Optional[int]:
        """
        Get the number of quantization bins, if specified.

        Returns
        -------
        quantization_bins : int | None
            Number of quantization bins, or None.
        """
        if self.is_categorical:
            return None
        return self._quantization_bins

    @property
    def quantized_values(self) -> Optional[Union[Iterable[int], Iterable[float]]]:
        """
        Get a sequence of quantized values for this tunable.

        Returns
        -------
        Optional[Union[Iterable[int], Iterable[float]]]
            If the Tunable is quantizable, returns a sequence of those elements,
            else None (e.g., for unquantized float type tunables).
        """
        num_range = self.range
        if self.type == "float":
            if not self.quantization_bins:
                return None
            # Be sure to return python types instead of numpy types.
            return (
                float(x)
                for x in np.linspace(
                    start=num_range[0],
                    stop=num_range[1],
                    num=self.quantization_bins,
                    endpoint=True,
                )
            )
        assert self.type == "int", f"Unhandled tunable type: {self}"
        # Without quantization, every integer in the (inclusive) range is a value.
        return range(
            int(num_range[0]),
            int(num_range[1]) + 1,
            int(self.span / (self.quantization_bins - 1)) if self.quantization_bins else 1,
        )

    @property
    def cardinality(self) -> Optional[int]:
        """
        Gets the cardinality of elements in this tunable, or else None. (i.e., when the
        tunable is continuous float and not quantized).

        If the tunable has quantization set, this is the number of quantization bins.

        Returns
        -------
        cardinality : int
            Either the number of points in the tunable or else None.
        """
        if self.is_categorical:
            return len(self.categories)
        if self.quantization_bins:
            return self.quantization_bins
        if self.type == "int":
            return int(self.span) + 1
        return None

    @property
    def is_log(self) -> Optional[bool]:
        """
        Check if numeric tunable is log scale.

        Asserts that the tunable is numerical.

        Returns
        -------
        log : bool
            The `log` setting from the config (True for log scale), or None if it
            was not specified.
        """
        assert self.is_numerical
        return self._log

    @property
    def distribution(self) -> Optional[DistributionName]:
        """
        Get the name of the distribution (uniform, normal, or beta) if specified.

        Returns
        -------
        distribution : str
            Name of the distribution (uniform, normal, or beta) or None.
        """
        return self._distribution

    @property
    def distribution_params(self) -> Dict[str, float]:
        """
        Get the parameters of the distribution.

        Asserts that a distribution is specified.

        Returns
        -------
        distribution_params : Dict[str, float]
            Parameters of the distribution. Can be empty.
        """
        assert self._distribution is not None
        return self._distribution_params

    @property
    def categories(self) -> List[Optional[str]]:
        """
        Get the list of all possible values of a categorical tunable.

        Asserts that the tunable is categorical and that its values are set.

        Returns
        -------
        values : List[str]
            List of all possible values of a categorical tunable.
        """
        assert self.is_categorical
        assert self._values is not None
        return self._values

    @property
    def values(self) -> Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]:
        """
        Gets the categories or quantized values for this tunable.

        Returns
        -------
        Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]
            Categories or quantized values.
        """
        if self.is_categorical:
            return self.categories
        assert self.is_numerical
        return self.quantized_values

    @property
    def meta(self) -> Dict[str, Any]:
        """
        Get the tunable's metadata.

        This is a free-form dictionary that can be used to store any additional
        information about the tunable (e.g., the unit information).
        """
        return self._meta