Coverage for mlos_core/mlos_core/spaces/converters/util.py: 90%
30 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Helper functions for config space converters."""
7from ConfigSpace import ConfigurationSpace
8from ConfigSpace.functional import quantize
9from ConfigSpace.hyperparameters import Hyperparameter, NumericalHyperparameter
# Key looked up in ``Hyperparameter.meta`` for the requested number of quantization bins.
QUANTIZATION_BINS_META_KEY = "quantization_bins"


def monkey_patch_hp_quantization(hp: Hyperparameter) -> Hyperparameter:
    """
    Monkey-patch quantization into the Hyperparameter.

    Temporary workaround to dropped quantization support in ConfigSpace 1.0
    See Also: <https://github.com/automl/ConfigSpace/issues/390>

    The patch is applied in place to ``hp._vector_dist.sample_vector``. The
    original sampler is stashed as ``sample_vector_mlos_orig`` so repeated calls
    never stack patches. If ``hp.meta`` no longer requests quantization, any
    earlier patch is removed.

    Parameters
    ----------
    hp : Hyperparameter
        ConfigSpace hyperparameter to patch. Non-numerical hyperparameters are
        returned unchanged.

    Returns
    -------
    hp : Hyperparameter
        The same (patched) hyperparameter object.

    Raises
    ------
    ValueError
        If the ``quantization_bins`` meta value is not an integer, or is not
        greater than 1.
    """
    if not isinstance(hp, NumericalHyperparameter):
        return hp

    dist = hp._vector_dist  # pylint: disable=protected-access
    quantization_bins = (hp.meta or {}).get(QUANTIZATION_BINS_META_KEY)
    if quantization_bins is None:
        # No quantization requested.
        # Remove any previously applied patches.
        if hasattr(dist, "sample_vector_mlos_orig"):
            setattr(dist, "sample_vector", dist.sample_vector_mlos_orig)
            delattr(dist, "sample_vector_mlos_orig")
        return hp

    # Reject non-integral floats explicitly; int() would otherwise silently truncate them.
    if isinstance(quantization_bins, float) and not quantization_bins.is_integer():
        raise ValueError(f"{quantization_bins=} :: must be an integer.")
    try:
        quantization_bins = int(quantization_bins)
    except (TypeError, ValueError) as ex:
        # int() raises TypeError for objects it cannot convert at all (lists, dicts, ...),
        # and ValueError for unparsable strings; report both the same way.
        raise ValueError(f"{quantization_bins=} :: must be an integer.") from ex

    if quantization_bins <= 1:
        raise ValueError(f"{quantization_bins=} :: must be greater than 1.")

    # Stash the original sampler only once, so re-patching wraps the original
    # sampler rather than an earlier patch.
    if not hasattr(dist, "sample_vector_mlos_orig"):
        setattr(dist, "sample_vector_mlos_orig", dist.sample_vector)

    def _quantized_sample_vector(n, *, seed=None):
        # Draw from the original sampler, then pass the samples through ConfigSpace's
        # ``quantize`` using the distribution's vectorized bounds and the requested bins.
        return quantize(
            dist.sample_vector_mlos_orig(n, seed=seed),
            bounds=(dist.lower_vectorized, dist.upper_vectorized),
            bins=quantization_bins,
        )

    setattr(dist, "sample_vector", _quantized_sample_vector)
    return hp
def monkey_patch_cs_quantization(cs: ConfigurationSpace) -> ConfigurationSpace:
    """
    Apply ``monkey_patch_hp_quantization`` to every Hyperparameter of a ConfigSpace.

    Each hyperparameter is patched in place; the ConfigSpace object itself is
    handed back for call chaining.

    Parameters
    ----------
    cs : ConfigurationSpace
        ConfigSpace whose Hyperparameters should be patched.

    Returns
    -------
    cs : ConfigurationSpace
        The same ConfigSpace object, with its Hyperparameters patched.
    """
    hyperparameters = cs.values()
    for param in hyperparameters:
        # The per-hyperparameter helper mutates ``param`` directly, so its
        # return value is intentionally ignored.
        monkey_patch_hp_quantization(param)
    return cs