Coverage for mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py: 100%

128 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for grid search mlos_bench optimizer.""" 

6 

7import itertools 

8import math 

9import random 

10from typing import Dict, List 

11 

12import numpy as np 

13import pytest 

14 

15from mlos_bench.environments.status import Status 

16from mlos_bench.optimizers.grid_search_optimizer import GridSearchOptimizer 

17from mlos_bench.tunables.tunable import TunableValue 

18from mlos_bench.tunables.tunable_groups import TunableGroups 

19 

20# pylint: disable=redefined-outer-name 

21 

22 

@pytest.fixture
def grid_search_tunables_config() -> dict:
    """
    Provide the raw tunables config consumed by the grid search optimizer tests.

    A single tunable group keyed ``grid`` holding one categorical, one int, and
    one quantized float tunable.
    """
    cat_param = {
        "type": "categorical",
        "values": ["a", "b", "c"],
        "default": "a",
    }
    int_param = {
        "type": "int",
        "range": [1, 3],
        "default": 2,
    }
    # Quantization makes the float tunable discrete so it can take part in a grid.
    float_param = {
        "type": "float",
        "range": [0, 1],
        "default": 0.5,
        "quantization_bins": 5,
    }
    return {
        "grid": {
            "cost": 1,
            "params": {
                "cat": cat_param,
                "int": int_param,
                "float": float_param,
            },
        },
    }

49 

50 

@pytest.fixture
def grid_search_tunables_grid(
    grid_search_tunables: TunableGroups,
) -> List[Dict[str, TunableValue]]:
    """
    Test fixture for the full grid expanded directly from the tunable groups.

    Used to check that the optimizer's grid matches this one (ignoring order).
    Only tunables whose ``values`` is not ``None`` participate in the expansion.
    """
    # Collect (name, values) pairs in one pass so the names and the value lists
    # cannot fall out of alignment.
    discrete_tunables = [
        (tunable.name, tunable.values)
        for tunable, _group in grid_search_tunables
        if tunable.values is not None
    ]
    tunable_names = tuple(name for name, _values in discrete_tunables)
    tunables_params_values = [values for _name, values in discrete_tunables]
    # One dict per point of the cartesian product of all value lists.
    return [
        dict(zip(tunable_names, combo))
        for combo in itertools.product(*tunables_params_values)
    ]

69 

70 

@pytest.fixture
def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups:
    """Build the TunableGroups under test from the raw tunables config fixture."""
    tunables = TunableGroups(grid_search_tunables_config)
    return tunables

75 

76 

@pytest.fixture
def grid_search_opt(
    grid_search_tunables: TunableGroups,
    grid_search_tunables_grid: List[Dict[str, TunableValue]],
) -> GridSearchOptimizer:
    """
    Construct the GridSearchOptimizer under test.

    It optimizes two targets: ``score`` (maximized) and ``other_score`` (minimized).
    """
    assert len(grid_search_tunables) == 3
    # Pick a suggestion budget that is deliberately NOT a multiple of the grid
    # size (two full passes minus three) to exercise the convergence logic
    # partway through a repeated grid.
    grid_size = len(grid_search_tunables_grid)
    max_suggestions = 2 * grid_size - 3
    opt_config = {
        "max_suggestions": max_suggestions,
        "optimization_targets": {"score": "max", "other_score": "min"},
    }
    return GridSearchOptimizer(tunables=grid_search_tunables, config=opt_config)

94 

95 

def test_grid_search_grid(
    grid_search_opt: GridSearchOptimizer,
    grid_search_tunables: TunableGroups,
    grid_search_tunables_grid: List[Dict[str, TunableValue]],
) -> None:
    """Make sure that the grid search optimizer's initial pending configs are exactly
    the full cross-product of the tunables' values.
    """
    # Check the size. np.inf stands in for a falsy cardinality (e.g. None), which
    # would make the length comparison below fail.
    expected_grid_size = math.prod(
        tunable.cardinality or np.inf for tunable, _group in grid_search_tunables
    )
    assert expected_grid_size > len(grid_search_tunables)
    assert len(grid_search_tunables_grid) == expected_grid_size
    # Check for specific example configs inclusion.
    expected_config_example: Dict[str, TunableValue] = {
        "cat": "a",
        "int": 2,
        "float": 0.75,
    }
    grid_search_opt_pending_configs = list(grid_search_opt.pending_configs)
    assert expected_config_example in grid_search_tunables_grid
    assert expected_config_example in grid_search_opt_pending_configs
    # Check the rest of the contents.
    # Note: ConfigSpace param name order and TunableGroups param name order are not
    # consistent, and grid order is not relevant to us, so we compare whole dicts
    # for mutual containment instead of comparing the two lists directly.
    assert len(grid_search_opt_pending_configs) == expected_grid_size
    assert all(config in grid_search_tunables_grid for config in grid_search_opt_pending_configs)
    assert all(config in grid_search_opt_pending_configs for config in grid_search_tunables_grid)

125 

126 

def test_grid_search(
    grid_search_opt: GridSearchOptimizer,
    grid_search_tunables: TunableGroups,
    grid_search_tunables_grid: List[Dict[str, TunableValue]],
) -> None:
    """Make sure that grid search optimizer suggests and registers configs correctly,
    stops when the grid is exhausted, and refills the grid while budget remains.
    """
    score: Dict[str, TunableValue] = {"score": 1.0, "other_score": 2.0}
    status = Status.SUCCEEDED
    suggestion = grid_search_opt.suggest()
    # Snapshot the suggestion's values before restore_defaults() runs below.
    suggestion_dict = suggestion.get_param_values()
    default_config = grid_search_tunables.restore_defaults().get_param_values()

    # First suggestion should be the defaults.
    assert suggestion_dict == default_config
    # But that shouldn't be the first element in the grid search.
    assert suggestion_dict != next(iter(grid_search_tunables_grid))
    # The suggestion should no longer be in the pending_configs.
    assert suggestion_dict not in grid_search_opt.pending_configs
    # But it should be in the suggested_configs now (and the only one).
    assert list(grid_search_opt.suggested_configs) == [default_config]

    # Register a score for that suggestion.
    grid_search_opt.register(suggestion, status, score)
    # Now it shouldn't be in the suggested_configs.
    assert len(list(grid_search_opt.suggested_configs)) == 0

    # The pending configs should now be exactly the grid minus the default config.
    grid_search_tunables_grid.remove(default_config)
    # Materialize pending_configs once rather than once per grid element.
    pending_configs = list(grid_search_opt.pending_configs)
    assert default_config not in pending_configs
    assert all(config in grid_search_tunables_grid for config in pending_configs)
    assert all(config in pending_configs for config in grid_search_tunables_grid)

    # The next suggestion should be a different element in the grid search.
    suggestion = grid_search_opt.suggest()
    suggestion_dict = suggestion.get_param_values()
    assert suggestion_dict != default_config
    assert suggestion_dict not in grid_search_opt.pending_configs
    assert suggestion_dict in grid_search_opt.suggested_configs
    grid_search_opt.register(suggestion, status, score)
    assert suggestion_dict not in grid_search_opt.pending_configs
    assert suggestion_dict not in grid_search_opt.suggested_configs

    grid_search_tunables_grid.remove(suggestion_dict)
    pending_configs = list(grid_search_opt.pending_configs)
    assert all(config in grid_search_tunables_grid for config in pending_configs)
    assert all(config in pending_configs for config in grid_search_tunables_grid)

    # not_converged() turns False once EITHER the pending grid is exhausted OR
    # max_suggestions is reached. The first loop below exercises the former (with
    # iterations still left); the second loop exercises the latter.

    # Try to empty the rest of the grid.
    while grid_search_opt.not_converged():
        suggestion = grid_search_opt.suggest()
        grid_search_opt.register(suggestion, status, score)

    # The grid search should be empty now.
    assert not list(grid_search_opt.pending_configs)
    assert not list(grid_search_opt.suggested_configs)
    assert not grid_search_opt.not_converged()

    # But if we still have iterations left, we should be able to suggest again by
    # refilling the grid.
    assert grid_search_opt.current_iteration < grid_search_opt.max_suggestions
    assert grid_search_opt.suggest()
    assert list(grid_search_opt.pending_configs)
    assert list(grid_search_opt.suggested_configs)
    assert grid_search_opt.not_converged()

    # Try to finish the rest of our iterations by repeating the grid.
    while grid_search_opt.not_converged():
        suggestion = grid_search_opt.suggest()
        grid_search_opt.register(suggestion, status, score)
    assert not grid_search_opt.not_converged()
    assert grid_search_opt.current_iteration >= grid_search_opt.max_suggestions
    # The budget (not the grid) ran out this time, so configs remain outstanding.
    assert list(grid_search_opt.pending_configs)
    assert list(grid_search_opt.suggested_configs)

204 

205 

def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None:
    """Make sure that grid search optimizer works correctly when suggest and register
    are called out of order.
    """
    # pylint: disable=too-many-locals
    score: Dict[str, TunableValue] = {"score": 1.0, "other_score": 2.0}
    status = Status.SUCCEEDED
    suggest_count = 10
    suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
    suggested_shuffled = suggested.copy()
    # Try to ensure the shuffled list is different.
    for _ in range(3):
        random.shuffle(suggested_shuffled)
        if suggested_shuffled != suggested:
            break
    assert suggested != suggested_shuffled

    # Register the outstanding suggestions in a different order than they were made.
    for suggestion in suggested_shuffled:
        suggestion_dict = suggestion.get_param_values()
        assert suggestion_dict not in grid_search_opt.pending_configs
        assert suggestion_dict in grid_search_opt.suggested_configs
        grid_search_opt.register(suggestion, status, score)
        assert suggestion_dict not in grid_search_opt.suggested_configs

    # Every registration used the same score, so that is the best score so far.
    best_score, _ = grid_search_opt.get_best_observation()
    assert best_score == score

    # test re-register with higher score
    best_suggestion = suggested_shuffled[0]
    best_suggestion_dict = best_suggestion.get_param_values()
    assert best_suggestion_dict not in grid_search_opt.pending_configs
    assert best_suggestion_dict not in grid_search_opt.suggested_configs

    # Build a score that is strictly better on every target, respecting its direction.
    best_suggestion_score: Dict[str, TunableValue] = {}
    for opt_target, opt_dir in grid_search_opt.targets.items():
        val = score[opt_target]
        assert isinstance(val, (int, float))
        best_suggestion_score[opt_target] = val - 1 if opt_dir == "min" else val + 1

    grid_search_opt.register(best_suggestion, status, best_suggestion_score)
    assert best_suggestion_dict not in grid_search_opt.suggested_configs

    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == best_suggestion_score
    assert best_config == best_suggestion

    # Check bulk register
    suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
    assert all(
        suggestion.get_param_values() not in grid_search_opt.pending_configs
        for suggestion in suggested
    )
    assert all(
        suggestion.get_param_values() in grid_search_opt.suggested_configs
        for suggestion in suggested
    )

    # Those new suggestions also shouldn't be in the set of previously suggested configs.
    # Compare param-value dicts against param-value dicts: checking a dict for
    # membership in the list of TunableGroups objects would never find a match and
    # would make this assertion vacuous.
    previously_suggested = [prev.get_param_values() for prev in suggested_shuffled]
    assert all(
        suggestion.get_param_values() not in previously_suggested for suggestion in suggested
    )

    grid_search_opt.bulk_register(
        [suggestion.get_param_values() for suggestion in suggested],
        [score] * len(suggested),
        [status] * len(suggested),
    )

    assert all(
        suggestion.get_param_values() not in grid_search_opt.pending_configs
        for suggestion in suggested
    )
    assert all(
        suggestion.get_param_values() not in grid_search_opt.suggested_configs
        for suggestion in suggested
    )

    # The bulk-registered (worse) scores must not displace the best observation.
    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == best_suggestion_score
    assert best_config == best_suggestion

284 

285 

def test_grid_search_register(
    grid_search_opt: GridSearchOptimizer,
    grid_search_tunables: TunableGroups,
) -> None:
    """Check that `.register()` flips the sign of maximized targets and reports a
    failed trial as +inf on every target.
    """
    # "score" is maximized, so it comes back negated; "other_score" is minimized
    # and comes back unchanged.
    raw_scores = {
        "score": 1.0,
        "other_score": 2.0,
    }
    expected_adjusted = {
        "score": -1.0,  # max
        "other_score": 2.0,  # min
    }
    adjusted = grid_search_opt.register(grid_search_tunables, Status.SUCCEEDED, raw_scores)
    assert adjusted == expected_adjusted

    # A FAILED registration carries no scores at all.
    expected_failed = {
        "score": float("inf"),
        "other_score": float("inf"),
    }
    assert grid_search_opt.register(grid_search_tunables, Status.FAILED) == expected_failed