Coverage for mlos_core/mlos_core/tests/spaces/monkey_patch_quantization_test.py: 100%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for ConfigSpace quantization monkey patching.""" 

6 

7import numpy as np 

8from ConfigSpace import ( 

9 ConfigurationSpace, 

10 UniformFloatHyperparameter, 

11 UniformIntegerHyperparameter, 

12) 

13from numpy.random import RandomState 

14 

15from mlos_core.spaces.converters.util import ( 

16 QUANTIZATION_BINS_META_KEY, 

17 monkey_patch_cs_quantization, 

18 monkey_patch_hp_quantization, 

19) 

20from mlos_core.tests import SEED 

21 

22 

def test_configspace_quant_int() -> None:
    """Check the quantization of an integer hyperparameter.

    An integer hyperparameter over ``[0, 100]`` annotated with 11 quantization bins
    must, once patched, only yield values from ``{0, 10, ..., 100}``. All sampling
    is seeded with ``SEED`` so the test is deterministic.
    """
    quantization_bins = 11
    quantized_values = set(range(0, 101, 10))
    hp = UniformIntegerHyperparameter(
        "hp",
        lower=0,
        upper=100,
        log=False,
        meta={QUANTIZATION_BINS_META_KEY: quantization_bins},
    )

    # Before patching: expect that at least one value is not quantized
    # (the meta annotation by itself must not affect sampling).
    unpatched_samples = set(hp.sample_value(100, seed=RandomState(SEED)))
    assert not unpatched_samples.issubset(quantized_values)

    monkey_patch_hp_quantization(hp)
    # After patching: *all* values must belong to the set of quantized values.
    # Exercise both the scalar (size=None) and the batched sampling paths.
    assert hp.sample_value(seed=RandomState(SEED)) in quantized_values
    patched_samples = set(hp.sample_value(100, seed=RandomState(SEED)))
    assert patched_samples.issubset(quantized_values)

42 

43 

def test_configspace_quant_float() -> None:
    """Check the quantization of a float hyperparameter.

    A float hyperparameter over ``[0, 1]`` annotated with 5 quantization bins must,
    once patched, only yield values from ``np.linspace(0, 1, 5)``. All sampling is
    seeded with ``SEED`` so the test is deterministic.
    """
    # 5 is a nice number of bins to avoid floating point errors: the grid points
    # 0, 0.25, 0.5, 0.75, 1 are exactly representable, so set membership is exact.
    quantization_bins = 5
    quantized_values = set(np.linspace(0, 1, num=quantization_bins, endpoint=True))
    hp = UniformFloatHyperparameter(
        "hp",
        lower=0,
        upper=1,
        log=False,
        meta={QUANTIZATION_BINS_META_KEY: quantization_bins},
    )

    # Before patching: expect that at least one value is not quantized
    # (the meta annotation by itself must not affect sampling).
    unpatched_samples = set(hp.sample_value(100, seed=RandomState(SEED)))
    assert not unpatched_samples.issubset(quantized_values)

    monkey_patch_hp_quantization(hp)
    # After patching: *all* values must belong to the set of quantized values.
    # Exercise both the scalar (size=None) and the batched sampling paths.
    assert hp.sample_value(seed=RandomState(SEED)) in quantized_values
    patched_samples = set(hp.sample_value(100, seed=RandomState(SEED)))
    assert patched_samples.issubset(quantized_values)

64 

65 

def test_configspace_quant_repatch() -> None:
    """Repatch the same hyperparameter with different number of bins.

    Uses one integer hyperparameter over ``[0, 100]`` and checks three transitions:

    1. Patching twice with the same 11 bins gives identical seeded samples.
    2. Repatching with 21 bins yields values on the 5-step grid, including values
       the 10-step (11-bin) grid cannot produce.
    3. Removing the quantization meta key and repatching restores unquantized
       sampling.
    """
    quantization_bins = 11
    quantized_values = set(range(0, 101, 10))
    hp = UniformIntegerHyperparameter(
        "hp",
        lower=0,
        upper=100,
        log=False,
        meta={QUANTIZATION_BINS_META_KEY: quantization_bins},
    )

    # Before patching: expect that at least one value is not quantized.
    assert not set(hp.sample_value(100, seed=RandomState(SEED))).issubset(quantized_values)

    monkey_patch_hp_quantization(hp)
    # After patching: *all* values must belong to the set of quantized values.
    samples = hp.sample_value(100, seed=RandomState(SEED))
    assert set(samples).issubset(quantized_values)

    # Patch the same hyperparameter again: with the same seed the samples must be
    # identical, i.e. repatching with unchanged meta is idempotent.
    monkey_patch_hp_quantization(hp)
    assert np.array_equal(samples, hp.sample_value(100, seed=RandomState(SEED)))

    # Repatch with the higher number of bins (21 bins -> step of 5) and make sure
    # we get new values that the 11-bin grid could not produce.
    new_meta = dict(hp.meta or {})
    new_meta[QUANTIZATION_BINS_META_KEY] = 21
    hp.meta = new_meta
    monkey_patch_hp_quantization(hp)
    samples_set = set(hp.sample_value(100, seed=RandomState(SEED)))
    quantized_values_21 = set(range(0, 101, 5))
    # 5, 15, ..., 95: on the 21-bin grid but not on the 11-bin grid.
    values_only_in_21_bins = set(range(5, 96, 10))
    assert samples_set.issubset(quantized_values_21)
    assert not samples_set.isdisjoint(values_only_in_21_bins)

    # Repatch without quantization and make sure we get the original values.
    new_meta = dict(hp.meta or {})
    del new_meta[QUANTIZATION_BINS_META_KEY]
    hp.meta = new_meta
    assert hp.meta.get(QUANTIZATION_BINS_META_KEY) is None
    monkey_patch_hp_quantization(hp)
    samples_set = set(hp.sample_value(100, seed=RandomState(SEED)))
    assert samples_set.issubset(set(range(0, 101)))
    # Unquantized samples must show more distinct values than either grid allows
    # and must not be confined to the finer 21-bin grid.
    assert len(samples_set) > len(quantized_values_21) > len(quantized_values)
    assert not samples_set.issubset(quantized_values_21)

110 

111 

def test_configspace_quant() -> None:
    """Test quantization of multiple hyperparameters in the ConfigSpace.

    Only ``hp_int_quant`` (5 bins) and ``hp_float`` (11 bins) carry a quantization
    annotation. ``monkey_patch_cs_quantization`` is applied to the whole space, and
    the seeded samples are compared against values pinned for ``SEED``.
    """
    cs = ConfigurationSpace(
        name="cs_test",
        space={
            "hp_int": (0, 100000),
            "hp_int_quant": (0, 100000),
            "hp_float": (0.0, 1.0),
            "hp_categorical": ["a", "b", "c"],
            "hp_constant": 1337,
        },
    )
    # Annotate only the hyperparameters that should be quantized.
    for hp_name, n_bins in (("hp_int_quant", 5), ("hp_float", 11)):
        cs[hp_name].meta = {QUANTIZATION_BINS_META_KEY: n_bins}
    monkey_patch_cs_quantization(cs)

    # Pinned samples for SEED. Note that the quantized entries land on their grids:
    # hp_int_quant on multiples of 25000, hp_float on multiples of 0.1.
    expected_single = {
        "hp_categorical": "c",
        "hp_constant": 1337,
        "hp_float": 0.6,
        "hp_int": 60263,
        "hp_int_quant": 0,
    }
    expected_batch = [
        {
            "hp_categorical": "a",
            "hp_constant": 1337,
            "hp_float": 0.4,
            "hp_int": 59150,
            "hp_int_quant": 50000,
        },
        {
            "hp_categorical": "a",
            "hp_constant": 1337,
            "hp_float": 0.3,
            "hp_int": 65725,
            "hp_int_quant": 75000,
        },
        {
            "hp_categorical": "b",
            "hp_constant": 1337,
            "hp_float": 0.6,
            "hp_int": 84654,
            "hp_int_quant": 25000,
        },
    ]

    cs.seed(SEED)
    single = dict(cs.sample_configuration())
    assert single == expected_single
    batch = [dict(config) for config in cs.sample_configuration(3)]
    assert batch == expected_batch