Coverage for mlos_bench/mlos_bench/tunables/tunable.py: 96%
309 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Tunable parameter definition."""
6import copy
7import logging
8from collections.abc import Iterable, Sequence
9from typing import Any, Literal, TypedDict
11import numpy as np
13from mlos_bench.util import nullable
15_LOG = logging.getLogger(__name__)
TunableValue = int | float | str | None
"""A tunable parameter value type alias."""

TunableValueType = type[int] | type[float] | type[str]
"""Tunable value type (the Python type object used for value coercion)."""

TunableValueTypeTuple = (int, float, str, type(None))
"""
Tunable value type tuple.

For checking with isinstance().
"""

TunableValueTypeName = Literal["int", "float", "categorical"]
"""The string name of a tunable value type (as it appears in the json config)."""

TunableValuesDict = dict[str, TunableValue]
"""Tunable values dictionary type (tunable name -> current value)."""

DistributionName = Literal["uniform", "normal", "beta"]
"""Tunable value distribution type."""
class DistributionDictOpt(TypedDict, total=False):
    """
    A TypedDict for a :py:class:`.Tunable` parameter's optional ``distribution``'s
    config parameters.

    Mostly used by type checking. These are the types expected to be received from the
    json config.
    """

    # Optional (total=False): extra parameters of the distribution.
    # NOTE(review): the exact key names expected here are not visible in this
    # file - confirm against the JSON schema / samplers that consume them.
    params: dict[str, float] | None
class DistributionDict(DistributionDictOpt):
    """
    A TypedDict for a :py:class:`.Tunable` parameter's required ``distribution``'s
    config parameters.

    Mostly used by type checking. These are the types expected to be received from the
    json config.
    """

    # Required: one of "uniform", "normal", "beta" (see DistributionName).
    type: DistributionName
class TunableDictOpt(TypedDict, total=False):
    """
    A TypedDict for a :py:class:`.Tunable` parameter's optional config parameters.

    Mostly used for mypy type checking. These are the types expected to be received from
    the json config.
    """

    # Optional fields
    # Free-text description of the tunable.
    description: str | None
    # Allowed values (categorical tunables only - see _sanity_check_categorical).
    values: list[str | None] | None
    # [min, max] range (numerical tunables only - see _sanity_check_numerical).
    range: Sequence[int] | Sequence[float] | None
    # Number of quantization bins (numerical tunables only; must be > 1).
    quantization_bins: int | None
    # Log-scale flag (numerical tunables only).
    log: bool | None
    # Sampling distribution config (numerical tunables only).
    distribution: DistributionDict | None
    # Special values outside the regular range (numerical tunables only).
    special: list[int] | list[float] | None
    # Weights for the categorical values or the special values, respectively.
    values_weights: list[float] | None
    special_weights: list[float] | None
    # Weight of the regular (non-special) range.
    range_weight: float | None
    # Free-form metadata dictionary.
    meta: dict[str, Any]
class TunableDict(TunableDictOpt):
    """
    A TypedDict for a :py:class:`.Tunable` parameter's required config parameters.

    Mostly used for mypy type checking. These are the types expected to be received from
    the json config.
    """

    # Required fields
    # One of "int", "float", "categorical" (validated against Tunable.DTYPE).
    type: TunableValueTypeName
    # Default value; coerced to the tunable's dtype in Tunable.__init__ unless None.
    default: TunableValue
def tunable_dict_from_dict(config: dict[str, Any]) -> TunableDict:
    """
    Creates a TunableDict from a regular dict.

    Parameters
    ----------
    config : dict[str, Any]
        A regular dict that represents a TunableDict.

    Returns
    -------
    TunableDict

    Raises
    ------
    ValueError
        If the "type" entry is missing or is not a known tunable type.
    """
    param_type = config.get("type")
    if param_type not in Tunable.DTYPE:
        raise ValueError(f"Invalid parameter type: {param_type}")
    # Copy the recognized optional keys (missing ones become None),
    # preserving the same key order as the TunableDict declaration.
    optional_keys = (
        "description",
        "default",
        "values",
        "range",
        "quantization_bins",
        "log",
        "distribution",
        "special",
        "values_weights",
        "special_weights",
        "range_weight",
    )
    kwargs: dict[str, Any] = {key: config.get(key) for key in optional_keys}
    kwargs["meta"] = config.get("meta", {})
    return TunableDict(type=param_type, **kwargs)
class Tunable:  # pylint: disable=too-many-instance-attributes,too-many-public-methods
    """A tunable parameter definition and its current value."""

    # Maps the json type names to the Python types used to coerce values
    # (see the `value` setter and `dtype` property).
    DTYPE: dict[TunableValueTypeName, TunableValueType] = {
        "int": int,
        "float": float,
        "categorical": str,
    }
    """Maps Tunable types to their corresponding Python types by name."""
    def __init__(self, name: str, config: dict):
        """
        Create an instance of a new tunable parameter.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameter.
        config : dict
            Python dict that represents a Tunable (e.g., deserialized from JSON)

        Raises
        ------
        ValueError
            If the name or any part of the config is invalid
            (via :py:meth:`._sanity_check`).

        See Also
        --------
        :py:mod:`mlos_bench.tunables` : for more information on tunable parameters and
            their configuration.
        """
        # Normalizes and validates the raw config (raises on unknown "type").
        t_config = tunable_dict_from_dict(config)
        if not isinstance(name, str) or "!" in name:  # TODO: Use a regex here and in JSON schema
            raise ValueError(f"Invalid name of the tunable: {name}")
        self._name = name
        self._type: TunableValueTypeName = t_config["type"]  # required
        if self._type not in self.DTYPE:
            raise ValueError(f"Invalid parameter type: {self._type}")
        self._description = t_config.get("description")
        # Coerce the default to the tunable's Python type (None stays None).
        self._default = t_config["default"]
        self._default = self.dtype(self._default) if self._default is not None else self._default
        self._values = t_config.get("values")
        if self._values:
            # Categorical values are stored as strings (None is preserved).
            self._values = [str(v) if v is not None else v for v in self._values]
        self._meta: dict[str, Any] = t_config.get("meta", {})
        self._range: tuple[int, int] | tuple[float, float] | None = None
        self._quantization_bins: int | None = t_config.get("quantization_bins")
        self._log: bool | None = t_config.get("log")
        self._distribution: DistributionName | None = None
        self._distribution_params: dict[str, float] = {}
        distr = t_config.get("distribution")
        if distr:
            self._distribution = distr["type"]  # required
            self._distribution_params = distr.get("params") or {}
        # Normalize the range (a list in json) to a 2-tuple.
        config_range = config.get("range")
        if config_range is not None:
            assert len(config_range) == 2, f"Invalid range: {config_range}"
            config_range = (config_range[0], config_range[1])
        self._range = config_range
        self._special: list[int] | list[float] = t_config.get("special") or []
        # values_weights and special_weights are mutually exclusive by type
        # (enforced in the sanity checks below).
        self._weights: list[float] = (
            t_config.get("values_weights") or t_config.get("special_weights") or []
        )
        self._range_weight: float | None = t_config.get("range_weight")
        self._current_value = None
        self._sanity_check()
        # Assign via the property setter so the default gets coerced/validated.
        self.value = self._default
196 def _sanity_check(self) -> None:
197 """Check if the status of the Tunable is valid, and throw ValueError if it is
198 not.
199 """
200 if self.is_categorical:
201 self._sanity_check_categorical()
202 elif self.is_numerical:
203 self._sanity_check_numerical()
204 else:
205 raise ValueError(f"Invalid parameter type for tunable {self}: {self._type}")
206 if not self.is_valid(self.default):
207 raise ValueError(f"Invalid default value for tunable {self}: {self.default}")
209 def _sanity_check_categorical(self) -> None:
210 """Check if the status of the categorical Tunable is valid, and throw ValueError
211 if it is not.
212 """
213 # pylint: disable=too-complex
214 assert self.is_categorical
215 if not (self._values and isinstance(self._values, Iterable)):
216 raise ValueError(f"Must specify values for the categorical type tunable {self}")
217 if self._range is not None:
218 raise ValueError(f"Range must be None for the categorical type tunable {self}")
219 if len(set(self._values)) != len(self._values):
220 raise ValueError(f"Values must be unique for the categorical type tunable {self}")
221 if self._special:
222 raise ValueError(f"Categorical tunable cannot have special values: {self}")
223 if self._range_weight is not None:
224 raise ValueError(f"Categorical tunable cannot have range_weight: {self}")
225 if self._log is not None:
226 raise ValueError(f"Categorical tunable cannot have log parameter: {self}")
227 if self._quantization_bins is not None:
228 raise ValueError(f"Categorical tunable cannot have quantization parameter: {self}")
229 if self._distribution is not None:
230 raise ValueError(f"Categorical parameters do not support `distribution`: {self}")
231 if self._weights:
232 if len(self._weights) != len(self._values):
233 raise ValueError(f"Must specify weights for all values: {self}")
234 if any(w < 0 for w in self._weights):
235 raise ValueError(f"All weights must be non-negative: {self}")
237 def _sanity_check_numerical(self) -> None:
238 """Check if the status of the numerical Tunable is valid, and throw ValueError
239 if it is not.
240 """
241 # pylint: disable=too-complex,too-many-branches
242 assert self.is_numerical
243 if self._values is not None:
244 raise ValueError(f"Values must be None for the numerical type tunable {self}")
245 if not self._range or len(self._range) != 2 or self._range[0] >= self._range[1]:
246 raise ValueError(f"Invalid range for tunable {self}: {self._range}")
247 if self._quantization_bins is not None and self._quantization_bins <= 1:
248 raise ValueError(f"Number of quantization bins is <= 1: {self}")
249 if self._distribution is not None and self._distribution not in {
250 "uniform",
251 "normal",
252 "beta",
253 }:
254 raise ValueError(f"Invalid distribution: {self}")
255 if self._distribution_params and self._distribution is None:
256 raise ValueError(f"Must specify the distribution: {self}")
257 if self._weights:
258 if self._range_weight is None:
259 raise ValueError(f"Must specify weight for the range: {self}")
260 if len(self._weights) != len(self._special):
261 raise ValueError("Must specify weights for all special values {self}")
262 if any(w < 0 for w in self._weights + [self._range_weight]):
263 raise ValueError(f"All weights must be non-negative: {self}")
264 elif self._range_weight is not None:
265 raise ValueError(f"Must specify both weights and range_weight or none: {self}")
267 def __repr__(self) -> str:
268 """
269 Produce a human-readable version of the Tunable (mostly for logging).
271 Returns
272 -------
273 string : str
274 A human-readable version of the Tunable.
275 """
276 # TODO? Add weights, specials, quantization, distribution?
277 if self.is_categorical:
278 return (
279 f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}"
280 )
281 return f"{self._name}[{self._type}]({self._range}:{self._default})={self._current_value}"
283 def __eq__(self, other: object) -> bool:
284 """
285 Check if two Tunable objects are equal.
287 Parameters
288 ----------
289 other : Tunable
290 A tunable object to compare to.
292 Returns
293 -------
294 is_equal : bool
295 True if the Tunables correspond to the same parameter and have the same value and type.
296 NOTE: ranges and special values are not currently considered in the comparison.
297 """
298 if not isinstance(other, Tunable):
299 return False
300 return bool(
301 self._name == other._name
302 and self._type == other._type
303 and self._current_value == other._current_value
304 )
306 def __lt__(self, other: object) -> bool: # pylint: disable=too-many-return-statements
307 """
308 Compare the two Tunable objects. We mostly need this to create a canonical list
309 of tunable objects when hashing a TunableGroup.
311 Parameters
312 ----------
313 other : Tunable
314 A tunable object to compare to.
316 Returns
317 -------
318 is_less : bool
319 True if the current Tunable is less then the other one, False otherwise.
320 """
321 if not isinstance(other, Tunable):
322 return False
323 if self._name < other._name:
324 return True
325 if self._name == other._name and self._type < other._type:
326 return True
327 if self._name == other._name and self._type == other._type:
328 if self.is_numerical:
329 assert self._current_value is not None
330 assert other._current_value is not None
331 return bool(float(self._current_value) < float(other._current_value))
332 # else: categorical
333 if self._current_value is None:
334 return True
335 if other._current_value is None:
336 return False
337 return bool(str(self._current_value) < str(other._current_value))
338 return False
    def copy(self) -> "Tunable":
        """
        Deep copy of the Tunable object.

        Returns
        -------
        tunable : Tunable
            A new Tunable object that is a deep copy of the original one.
        """
        # Deep copy so that mutating the clone (e.g., updating its value)
        # does not affect this instance.
        return copy.deepcopy(self)
    @property
    def default(self) -> TunableValue:
        """Get the default value of the tunable."""
        # Already coerced to the tunable's dtype in __init__ (unless None).
        return self._default
    def is_default(self) -> bool:
        """Checks whether the currently assigned value of the tunable is at its
        default.
        """
        # NOTE(review): return annotation narrowed from TunableValue to bool --
        # the equality comparison of primitive values always yields a bool here.
        return self._default == self._current_value
    @property
    def value(self) -> TunableValue:
        """Get the current value of the tunable."""
        # Kept coerced/validated by the corresponding setter.
        return self._current_value
367 @value.setter
368 def value(self, value: TunableValue) -> TunableValue:
369 """Set the current value of the tunable."""
370 # We need this coercion for the values produced by some optimizers
371 # (e.g., scikit-optimize) and for data restored from certain storage
372 # systems (where values can be strings).
373 try:
374 if self.is_categorical and value is None:
375 coerced_value = None
376 else:
377 assert value is not None
378 coerced_value = self.dtype(value)
379 except Exception:
380 _LOG.error(
381 "Impossible conversion: %s %s <- %s %s",
382 self._type,
383 self._name,
384 type(value),
385 value,
386 )
387 raise
389 if self._type == "int" and isinstance(value, float) and value != coerced_value:
390 _LOG.error(
391 "Loss of precision: %s %s <- %s %s",
392 self._type,
393 self._name,
394 type(value),
395 value,
396 )
397 raise ValueError(f"Loss of precision: {self._name}={value}")
399 if not self.is_valid(coerced_value):
400 _LOG.error(
401 "Invalid assignment: %s %s <- %s %s",
402 self._type,
403 self._name,
404 type(value),
405 value,
406 )
407 raise ValueError(f"Invalid value for the tunable: {self._name}={value}")
409 self._current_value = coerced_value
410 return self._current_value
412 def update(self, value: TunableValue) -> bool:
413 """
414 Assign the value to the tunable. Return True if it is a new value, False
415 otherwise.
417 Parameters
418 ----------
419 value : int | float | str
420 Value to assign.
422 Returns
423 -------
424 is_updated : bool
425 True if the new value is different from the previous one, False otherwise.
426 """
427 prev_value = self._current_value
428 self.value = value
429 return prev_value != self._current_value
431 def is_valid(self, value: TunableValue) -> bool:
432 """
433 Check if the value can be assigned to the tunable.
435 Parameters
436 ----------
437 value : int | float | str
438 Value to validate.
440 Returns
441 -------
442 is_valid : bool
443 True if the value is valid, False otherwise.
444 """
445 if self.is_categorical and self._values:
446 return value in self._values
447 elif self.is_numerical and self._range:
448 if isinstance(value, (int, float)):
449 return self.in_range(value) or value in self._special
450 else:
451 raise ValueError(f"Invalid value type for tunable {self}: {value}={type(value)}")
452 else:
453 raise ValueError(f"Invalid parameter type: {self._type}")
455 def in_range(self, value: int | float | str | None) -> bool:
456 """
457 Check if the value is within the range of the tunable.
459 Do *NOT* check for special values. Return False if the tunable or value is
460 categorical or None.
461 """
462 return (
463 isinstance(value, (float, int))
464 and self.is_numerical
465 and self._range is not None
466 and bool(self._range[0] <= value <= self._range[1])
467 )
469 @property
470 def category(self) -> str | None:
471 """Get the current value of the tunable as a string."""
472 if self.is_categorical:
473 return nullable(str, self._current_value)
474 else:
475 raise ValueError("Cannot get categorical values for a numerical tunable.")
    @category.setter
    def category(self, new_value: str | None) -> str | None:
        """Set the current value of the (categorical) tunable."""
        # Only valid for categorical tunables; delegates coercion and
        # validation to the `value` setter.
        assert self.is_categorical
        assert isinstance(new_value, (str, type(None)))
        self.value = new_value
        return self.value
485 @property
486 def numerical_value(self) -> int | float:
487 """Get the current value of the tunable as a number."""
488 assert self._current_value is not None
489 if self._type == "int":
490 return int(self._current_value)
491 elif self._type == "float":
492 return float(self._current_value)
493 else:
494 raise ValueError("Cannot get numerical value for a categorical tunable.")
    @numerical_value.setter
    def numerical_value(self, new_value: int | float) -> int | float:
        """Set the current numerical value of the tunable."""
        # We need this coercion for the values produced by some optimizers
        # (e.g., scikit-optimize) and for data restored from certain storage
        # systems (where values can be strings).
        assert self.is_numerical
        self.value = new_value
        return self.value
    @property
    def name(self) -> str:
        """Get the name / string ID of the tunable."""
        # Validated in __init__ (must be a string without "!").
        return self._name
    @property
    def special(self) -> list[int] | list[float]:
        """
        Get the special values of the tunable. Return an empty list if there are none.

        Special values are only valid for numerical tunables
        (see :py:meth:`._sanity_check_categorical`).

        Returns
        -------
        special : list[int] | list[float]
            A list of special values of the tunable. Can be empty.
        """
        return self._special
    @property
    def is_special(self) -> bool:
        """
        Check if the current value of the tunable is special.

        Returns
        -------
        is_special : bool
            True if the current value of the tunable is special, False otherwise.
        """
        # Uses the `value` property, so the comparison sees the coerced value.
        return self.value in self._special
    @property
    def weights(self) -> list[float] | None:
        """
        Get the weights of the categories or special values of the tunable. Return None
        if there are none.

        Returns
        -------
        weights : list[float]
            A list of weights or None.
            NOTE(review): as written, this always returns the stored list,
            which defaults to [] rather than None - the `| None` in the
            annotation appears unreachable; confirm before tightening.
        """
        return self._weights
    @property
    def range_weight(self) -> float | None:
        """
        Get weight of the range of the numeric tunable.

        Must only be called on a numerical tunable that has special values and
        weights (asserts otherwise).

        Returns
        -------
        weight : float
            Weight of the range or None.
        """
        assert self.is_numerical
        assert self._special
        assert self._weights
        return self._range_weight
    @property
    def type(self) -> TunableValueTypeName:
        """
        Get the data type of the tunable.

        Returns
        -------
        type : str
            Data type of the tunable - one of {'int', 'float', 'categorical'}.
        """
        return self._type
    @property
    def dtype(self) -> TunableValueType:
        """
        Get the actual Python data type of the tunable.

        This is useful for bulk conversions of the input data.

        Returns
        -------
        dtype : type
            Data type of the tunable - one of {int, float, str}.
        """
        # Keyed by the validated `type` name, so this lookup cannot fail.
        return self.DTYPE[self._type]
    @property
    def is_categorical(self) -> bool:
        """
        Check if the tunable is categorical.

        Returns
        -------
        is_categorical : bool
            True if the tunable is categorical, False otherwise.
        """
        return self._type == "categorical"
    @property
    def is_numerical(self) -> bool:
        """
        Check if the tunable is an integer or float.

        Returns
        -------
        is_int : bool
            True if the tunable is an integer or float, False otherwise.
        """
        return self._type in {"int", "float"}
    @property
    def range(self) -> tuple[int, int] | tuple[float, float]:
        """
        Get the range of the tunable.

        Must only be called on a numerical tunable (asserts otherwise).

        Returns
        -------
        range : tuple[int, int] | tuple[float, float]
            A 2-tuple of numbers that represents the range of the tunable.
            Numbers can be int or float, depending on the type of the tunable.
        """
        assert self.is_numerical
        assert self._range is not None
        return self._range
629 @property
630 def span(self) -> int | float:
631 """
632 Gets the span of the range.
634 Note: this does not take quantization into account.
636 Returns
637 -------
638 Union[int, float]
639 (max - min) for numerical tunables.
640 """
641 num_range = self.range
642 return num_range[1] - num_range[0]
644 @property
645 def quantization_bins(self) -> int | None:
646 """
647 Get the number of quantization bins, if specified.
649 Returns
650 -------
651 quantization_bins : int | None
652 Number of quantization bins, or None.
653 """
654 if self.is_categorical:
655 return None
656 return self._quantization_bins
658 @property
659 def quantized_values(self) -> Iterable[int] | Iterable[float] | None:
660 """
661 Get a sequence of quantized values for this tunable.
663 Returns
664 -------
665 Optional[Union[Iterable[int], Iterable[float]]]
666 If the Tunable is quantizable, returns a sequence of those elements,
667 else None (e.g., for unquantized float type tunables).
668 """
669 num_range = self.range
670 if self.type == "float":
671 if not self.quantization_bins:
672 return None
673 # Be sure to return python types instead of numpy types.
674 return (
675 float(x)
676 for x in np.linspace(
677 start=num_range[0],
678 stop=num_range[1],
679 num=self.quantization_bins,
680 endpoint=True,
681 )
682 )
683 assert self.type == "int", f"Unhandled tunable type: {self}"
684 return range(
685 int(num_range[0]),
686 int(num_range[1]) + 1,
687 int(self.span / (self.quantization_bins - 1)) if self.quantization_bins else 1,
688 )
690 @property
691 def cardinality(self) -> int | None:
692 """
693 Gets the cardinality of elements in this tunable, or else None. (i.e., when the
694 tunable is continuous float and not quantized).
696 If the tunable has quantization set, this
698 Returns
699 -------
700 cardinality : int
701 Either the number of points in the tunable or else None.
702 """
703 if self.is_categorical:
704 return len(self.categories)
705 if self.quantization_bins:
706 return self.quantization_bins
707 if self.type == "int":
708 return int(self.span) + 1
709 return None
    @property
    def is_log(self) -> bool | None:
        """
        Check if numeric tunable is log scale.

        Must only be called on a numerical tunable (asserts otherwise).

        Returns
        -------
        log : bool | None
            True if numeric tunable is log scale, False if linear.
            None if the config did not specify it.
        """
        assert self.is_numerical
        return self._log
    @property
    def distribution(self) -> DistributionName | None:
        """
        Get the name of the distribution (uniform, normal, or beta) if specified.

        Returns
        -------
        distribution : str | None
            Name of the distribution (uniform, normal, or beta) or None.
        """
        return self._distribution
    @property
    def distribution_params(self) -> dict[str, float]:
        """
        Get the parameters of the distribution, if specified.

        Must only be called when a distribution is configured (asserts otherwise).

        Returns
        -------
        distribution_params : dict[str, float]
            Parameters of the distribution. Can be empty.
        """
        assert self._distribution is not None
        return self._distribution_params
    @property
    def categories(self) -> list[str | None]:
        """
        Get the list of all possible values of a categorical tunable.

        Must only be called on a categorical tunable (asserts otherwise).

        Returns
        -------
        values : list[str | None]
            List of all possible values of a categorical tunable.
        """
        assert self.is_categorical
        assert self._values is not None
        return self._values
764 @property
765 def values(self) -> Iterable[str | None] | Iterable[int] | Iterable[float] | None:
766 """
767 Gets the categories or quantized values for this tunable.
769 Returns
770 -------
771 Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]
772 Categories or quantized values.
773 """
774 if self.is_categorical:
775 return self.categories
776 assert self.is_numerical
777 return self.quantized_values
    @property
    def meta(self) -> dict[str, Any]:
        """
        Get the tunable's metadata.

        This is a free-form dictionary that can be used to store any additional
        information about the tunable (e.g., the unit information).
        """
        return self._meta