Coverage for mlos_core/mlos_core/spaces/converters/util.py: 90%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Helper functions for config space converters.""" 

6 

7from ConfigSpace import ConfigurationSpace 

8from ConfigSpace.functional import quantize 

9from ConfigSpace.hyperparameters import Hyperparameter, NumericalHyperparameter 

10 

# Key under which a Hyperparameter's ``meta`` dict carries the requested number
# of quantization bins (looked up by monkey_patch_hp_quantization).
QUANTIZATION_BINS_META_KEY: str = "quantization_bins"

12 

13 

def monkey_patch_hp_quantization(hp: Hyperparameter) -> Hyperparameter:
    """
    Monkey-patch quantization into the Hyperparameter.

    Temporary workaround to dropped quantization support in ConfigSpace 1.0
    See Also: <https://github.com/automl/ConfigSpace/issues/390>

    The number of bins is read from ``hp.meta[QUANTIZATION_BINS_META_KEY]``.
    When it is present, the ``sample_vector`` method of the hyperparameter's
    vector distribution is replaced by a wrapper that passes the original
    sampler's output through ``ConfigSpace.functional.quantize``.
    When it is absent, any previously applied patch is removed.
    Non-numerical hyperparameters are returned untouched.

    Parameters
    ----------
    hp : Hyperparameter
        ConfigSpace hyperparameter to patch (modified in place).

    Returns
    -------
    hp : Hyperparameter
        Patched hyperparameter (the same object that was passed in).

    Raises
    ------
    ValueError
        If the requested number of quantization bins cannot be interpreted
        as an integer (including non-integral floats), or is not greater than 1.
    """
    if not isinstance(hp, NumericalHyperparameter):
        return hp

    dist = hp._vector_dist  # pylint: disable=protected-access
    quantization_bins = (hp.meta or {}).get(QUANTIZATION_BINS_META_KEY)
    if quantization_bins is None:
        # No quantization requested.
        # Remove any previously applied patch by restoring the saved original sampler.
        if hasattr(dist, "sample_vector_mlos_orig"):
            setattr(dist, "sample_vector", dist.sample_vector_mlos_orig)
            delattr(dist, "sample_vector_mlos_orig")
        return hp

    try:
        num_bins = int(quantization_bins)
    except (TypeError, ValueError, OverflowError) as ex:
        # TypeError: unsupported type (e.g. a list or dict);
        # OverflowError: e.g. float("inf"). Report all as a bad value.
        raise ValueError(f"{quantization_bins=} :: must be an integer.") from ex
    if isinstance(quantization_bins, float) and not quantization_bins.is_integer():
        # int() would silently truncate (e.g. 10.7 -> 10); reject instead.
        raise ValueError(f"{quantization_bins=} :: must be an integer.")
    quantization_bins = num_bins

    if quantization_bins <= 1:
        raise ValueError(f"{quantization_bins=} :: must be greater than 1.")

    # Save the unpatched sampler only once, so that re-patching (e.g. with a
    # different bin count) always wraps the true original, never a prior patch.
    if not hasattr(dist, "sample_vector_mlos_orig"):
        setattr(dist, "sample_vector_mlos_orig", dist.sample_vector)

    assert hasattr(dist, "sample_vector_mlos_orig")
    setattr(
        dist,
        "sample_vector",
        lambda n, *, seed=None: quantize(
            dist.sample_vector_mlos_orig(n, seed=seed),
            bounds=(dist.lower_vectorized, dist.upper_vectorized),
            bins=quantization_bins,
        ),
    )
    return hp

67 

68 

def monkey_patch_cs_quantization(cs: ConfigurationSpace) -> ConfigurationSpace:
    """
    Apply ``monkey_patch_hp_quantization`` to every Hyperparameter of a ConfigSpace.

    Parameters
    ----------
    cs : ConfigurationSpace
        ConfigSpace whose Hyperparameters get patched (in place).

    Returns
    -------
    cs : ConfigurationSpace
        The very same ConfigSpace object that was passed in, for chaining.
    """
    for param in cs.values():
        # The helper patches (or un-patches) the hyperparameter in place, so
        # its return value is intentionally discarded.
        monkey_patch_hp_quantization(param)
    return cs