Coverage for mlos_core/mlos_core/spaces/converters/util.py: 90%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Helper functions for config space converters.""" 

6 

7from ConfigSpace import ConfigurationSpace 

8from ConfigSpace.functional import quantize 

9from ConfigSpace.hyperparameters import Hyperparameter, NumericalHyperparameter 

10 

# Key looked up in a Hyperparameter's ``meta`` dict to request quantized
# sampling with the given number of bins (see ``monkey_patch_hp_quantization``).
QUANTIZATION_BINS_META_KEY: str = "quantization_bins"

12 

13 

def monkey_patch_hp_quantization(hp: Hyperparameter) -> Hyperparameter:
    """
    Monkey-patch quantization into the Hyperparameter.

    Temporary workaround for the quantization support dropped in ConfigSpace 1.0.

    Notes
    -----
    See <https://github.com/automl/ConfigSpace/issues/390>.

    The number of quantization bins is read from
    ``hp.meta[QUANTIZATION_BINS_META_KEY]``:

    - If the key is missing (or maps to ``None``), any patch previously applied
      by this function is removed and the hyperparameter is returned.
    - Otherwise the ``sample_vector`` method of the hyperparameter's vector
      distribution is replaced with a wrapper that passes the original samples
      through ``ConfigSpace.functional.quantize``. The original method is saved
      as ``sample_vector_mlos_orig`` so that calling this function again
      re-wraps the original method instead of stacking wrappers.

    Non-numerical hyperparameters are returned unchanged.

    Parameters
    ----------
    hp : ConfigSpace.hyperparameters.Hyperparameter
        ConfigSpace hyperparameter to patch.

    Returns
    -------
    hp : ConfigSpace.hyperparameters.Hyperparameter
        The same hyperparameter object (its vector distribution is modified
        in place).

    Raises
    ------
    ValueError
        If the quantization bins meta value cannot be converted to an integer
        (including values such as lists or dicts, for which ``int()`` raises
        ``TypeError``), or if it is not greater than 1.
    """
    if not isinstance(hp, NumericalHyperparameter):
        return hp

    assert isinstance(hp, NumericalHyperparameter)
    dist = hp._vector_dist  # pylint: disable=protected-access
    quantization_bins = (hp.meta or {}).get(QUANTIZATION_BINS_META_KEY)
    if quantization_bins is None:
        # No quantization requested.
        # Remove any previously applied patches by restoring the saved method.
        if hasattr(dist, "sample_vector_mlos_orig"):
            setattr(
                dist,
                "sample_vector",
                dist.sample_vector_mlos_orig,  # pyright: ignore[reportAttributeAccessIssue]
            )
            delattr(dist, "sample_vector_mlos_orig")
        return hp

    try:
        # NOTE: int() also truncates floats (e.g. 2.5 -> 2) and parses numeric strings.
        quantization_bins = int(quantization_bins)
    except (TypeError, ValueError) as ex:
        # int() raises TypeError (not ValueError) for e.g. lists, dicts, or
        # arbitrary objects; surface both as the documented ValueError.
        raise ValueError(f"{quantization_bins=} :: must be an integer.") from ex

    if quantization_bins <= 1:
        raise ValueError(f"{quantization_bins=} :: must be greater than 1.")

    # Save the original sampler only once, so re-patching (e.g. with a
    # different number of bins) wraps the original rather than a prior wrapper.
    if not hasattr(dist, "sample_vector_mlos_orig"):
        setattr(dist, "sample_vector_mlos_orig", dist.sample_vector)

    assert hasattr(dist, "sample_vector_mlos_orig")
    setattr(
        dist,
        "sample_vector",
        lambda n, *, seed=None: quantize(
            dist.sample_vector_mlos_orig(  # pyright: ignore[reportAttributeAccessIssue]
                n,
                seed=seed,
            ),
            bounds=(dist.lower_vectorized, dist.upper_vectorized),
            bins=quantization_bins,
        ),
    )
    return hp

77 

78 

def monkey_patch_cs_quantization(cs: ConfigurationSpace) -> ConfigurationSpace:
    """
    Apply ``monkey_patch_hp_quantization`` to every Hyperparameter of a ConfigSpace.

    Each hyperparameter is patched in place; the per-hyperparameter return
    values are not used. The same ConfigSpace object is returned so the call
    can be chained.

    Parameters
    ----------
    cs : ConfigSpace.ConfigurationSpace
        ConfigSpace whose hyperparameters should be patched.

    Returns
    -------
    cs : ConfigSpace.ConfigurationSpace
        The same ConfigSpace instance, after patching.
    """
    hyperparameters = cs.values()
    for param in hyperparameters:
        monkey_patch_hp_quantization(param)
    return cs