Coverage for mlos_core/mlos_core/spaces/converters/util.py: 90%
30 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Helper functions for config space converters."""
7from ConfigSpace import ConfigurationSpace
8from ConfigSpace.functional import quantize
9from ConfigSpace.hyperparameters import Hyperparameter, NumericalHyperparameter
# Key looked up in a Hyperparameter's ``meta`` dict to find the requested
# number of quantization bins (read by monkey_patch_hp_quantization, which
# requires the value to be an integer greater than 1).
QUANTIZATION_BINS_META_KEY = "quantization_bins"
def _parse_quantization_bins(raw_bins: object) -> int:
    """
    Validate a ``quantization_bins`` meta value and convert it to an int.

    Parameters
    ----------
    raw_bins : object
        Value stored under QUANTIZATION_BINS_META_KEY in a hyperparameter's meta.

    Returns
    -------
    quantization_bins : int
        The number of quantization bins (always greater than 1).

    Raises
    ------
    ValueError
        If the value is not integer-like (including non-integral floats,
        NaN/inf, and non-numeric objects) or is not greater than 1.
    """
    quantization_bins = raw_bins
    # int() would silently truncate e.g. 2.5 -> 2 and quantize onto a
    # different grid than requested, so reject non-integral floats up front.
    # (NaN and inf are also caught here since they are not integral.)
    if isinstance(quantization_bins, float) and not quantization_bins.is_integer():
        raise ValueError(f"{quantization_bins=} :: must be an integer.")
    try:
        quantization_bins = int(quantization_bins)
    except (TypeError, ValueError) as ex:
        # TypeError covers values int() cannot handle at all (e.g. lists),
        # which previously escaped without the intended error message.
        raise ValueError(f"{quantization_bins=} :: must be an integer.") from ex
    if quantization_bins <= 1:
        raise ValueError(f"{quantization_bins=} :: must be greater than 1.")
    return quantization_bins


def monkey_patch_hp_quantization(hp: Hyperparameter) -> Hyperparameter:
    """
    Monkey-patch quantization into the Hyperparameter.

    Temporary workaround to dropped quantization support in ConfigSpace 1.0

    Notes
    -----
    See <https://github.com/automl/ConfigSpace/issues/390>.

    Only NumericalHyperparameters are affected; any other hyperparameter is
    returned unchanged. The patch replaces ``sample_vector`` on the
    hyperparameter's vector distribution with a wrapper that passes the
    original samples through ``ConfigSpace.functional.quantize`` using the
    bin count from ``hp.meta[QUANTIZATION_BINS_META_KEY]``. The unwrapped
    sampler is kept as ``sample_vector_mlos_orig``. This lets the patch be
    re-applied with a new bin count, or removed when the meta key is absent,
    without stacking wrappers. The hyperparameter is modified in place and
    also returned.

    Parameters
    ----------
    hp : ConfigSpace.hyperparameters.Hyperparameter
        ConfigSpace hyperparameter to patch.

    Returns
    -------
    hp : ConfigSpace.hyperparameters.Hyperparameter
        Patched hyperparameter.

    Raises
    ------
    ValueError
        If the ``quantization_bins`` meta value is not an integer (including
        non-integral floats) or is not greater than 1.
    """
    if not isinstance(hp, NumericalHyperparameter):
        return hp

    assert isinstance(hp, NumericalHyperparameter)
    dist = hp._vector_dist  # pylint: disable=protected-access
    raw_bins = (hp.meta or {}).get(QUANTIZATION_BINS_META_KEY)
    if raw_bins is None:
        # No quantization requested.
        # Remove any previously applied patches.
        if hasattr(dist, "sample_vector_mlos_orig"):
            setattr(
                dist,
                "sample_vector",
                dist.sample_vector_mlos_orig,  # pyright: ignore[reportAttributeAccessIssue]
            )
            delattr(dist, "sample_vector_mlos_orig")
        return hp

    # Validate before touching the distribution so a bad value leaves it unpatched.
    quantization_bins = _parse_quantization_bins(raw_bins)

    # Save the unpatched sampler only once, so re-patching wraps the original
    # sampler rather than a previously installed wrapper.
    if not hasattr(dist, "sample_vector_mlos_orig"):
        setattr(dist, "sample_vector_mlos_orig", dist.sample_vector)

    assert hasattr(dist, "sample_vector_mlos_orig")
    setattr(
        dist,
        "sample_vector",
        lambda n, *, seed=None: quantize(
            dist.sample_vector_mlos_orig(  # pyright: ignore[reportAttributeAccessIssue]
                n,
                seed=seed,
            ),
            bounds=(dist.lower_vectorized, dist.upper_vectorized),
            bins=quantization_bins,
        ),
    )
    return hp
def monkey_patch_cs_quantization(cs: ConfigurationSpace) -> ConfigurationSpace:
    """
    Apply the quantization monkey-patch to every Hyperparameter of a ConfigSpace.

    Each hyperparameter is passed to ``monkey_patch_hp_quantization``, which
    patches it in place by mutating its vector distribution. The same
    ConfigSpace object that was passed in is then returned.

    Parameters
    ----------
    cs : ConfigSpace.ConfigurationSpace
        ConfigSpace whose hyperparameters should be patched.

    Returns
    -------
    cs : ConfigSpace.ConfigurationSpace
        The same (now patched) ConfigSpace instance.
    """
    hyperparameters = cs.values()
    for param in hyperparameters:
        monkey_patch_hp_quantization(param)
    return cs