#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Unit tests for grid search mlos_bench optimizer."""

import itertools
import math
import random

import numpy as np
import pytest

from mlos_bench.environments.status import Status
from mlos_bench.optimizers.grid_search_optimizer import GridSearchOptimizer
from mlos_bench.tunables.tunable import TunableValue
from mlos_bench.tunables.tunable_groups import TunableGroups

# pylint: disable=redefined-outer-name


@pytest.fixture
def grid_search_tunables_config() -> dict:
    """Test fixture for grid search optimizer tunables config."""
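    # Note: with "quantization_bins": 5 over the range [0, 1], the float tunable
    # takes the 5 values 0.0, 0.25, 0.5, 0.75, 1.0, so the full grid below has
    # 3 (cat) * 3 (int) * 5 (float) = 45 configs.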

    return {
        "grid": {
            "cost": 1,
            "params": {
                "cat": {
                    "type": "categorical",
                    "values": ["a", "b", "c"],
                    "default": "a",
                },
                "int": {
                    "type": "int",
                    "range": [1, 3],
                    "default": 2,
                },
                "float": {
                    "type": "float",
                    "range": [0, 1],
                    "default": 0.5,
                    "quantization_bins": 5,
                },
            },
        },
    }


@pytest.fixture
def grid_search_tunables_grid(
    grid_search_tunables: TunableGroups,
) -> list[dict[str, TunableValue]]:
    """
    Test fixture for grid from tunable groups.

    Used to check that the grids are the same (ignoring order).
    """
    tunables_params_values = [
        tunable.values for tunable, _group in grid_search_tunables if tunable.values is not None
    ]
    tunable_names = tuple(
        tunable.name for tunable, _group in grid_search_tunables if tunable.values is not None
    )
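    # Build the full cross product of the enumerable tunable values,
    # one dict per grid point, keyed by tunable name.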

    return list(
        dict(zip(tunable_names, combo)) for combo in itertools.product(*tunables_params_values)
    )


@pytest.fixture
def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups:
    """Test fixture for grid search optimizer tunables."""
    return TunableGroups(grid_search_tunables_config)


@pytest.fixture
def grid_search_opt(
    grid_search_tunables: TunableGroups,
    grid_search_tunables_grid: list[dict[str, TunableValue]],
) -> GridSearchOptimizer:
    """Test fixture for grid search optimizer."""
    assert len(grid_search_tunables) == 3
    # Test the convergence logic by making the number of iterations something
    # other than a multiple of the number of elements in the grid.
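    # (Here: 2 * 45 - 3 = 87 suggestions for a 45-point grid, so the grid gets
    # exhausted once, refilled, and then only partially consumed on the second pass.)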

    max_suggestions = len(grid_search_tunables_grid) * 2 - 3
    return GridSearchOptimizer(
        tunables=grid_search_tunables,
        config={
            "max_suggestions": max_suggestions,
            "optimization_targets": {"score": "max", "other_score": "min"},
        },
    )


def test_grid_search_grid(
    grid_search_opt: GridSearchOptimizer,
    grid_search_tunables: TunableGroups,
    grid_search_tunables_grid: list[dict[str, TunableValue]],
) -> None:
    """Make sure that the grid search optimizer initializes its grid correctly."""
    # Check the size.
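    # Note: Tunable.cardinality is expected to be falsy (None) for un-quantized
    # continuous tunables, hence the np.inf fallback below.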

    expected_grid_size = math.prod(
        tunable.cardinality or np.inf for tunable, _group in grid_search_tunables
    )
    assert expected_grid_size > len(grid_search_tunables)
    assert len(grid_search_tunables_grid) == expected_grid_size
    # Check for specific example configs inclusion.
    expected_config_example: dict[str, TunableValue] = {
        "cat": "a",
        "int": 2,
        "float": 0.75,
    }
    grid_search_opt_pending_configs = list(grid_search_opt.pending_configs)
    assert expected_config_example in grid_search_tunables_grid
    assert expected_config_example in grid_search_opt_pending_configs
    # Check the rest of the contents.
    # Note: ConfigSpace param name vs TunableGroup parameter name order is not
    # consistent, so we need to do a full dict comparison.
    assert len(grid_search_opt_pending_configs) == expected_grid_size
    assert all(config in grid_search_tunables_grid for config in grid_search_opt_pending_configs)
    assert all(config in grid_search_opt_pending_configs for config in grid_search_tunables_grid)
    # Order is less relevant to us, so we'll just check that the sets are the same.
    # assert grid_search_opt.pending_configs == grid_search_tunables_grid


def test_grid_search(
    grid_search_opt: GridSearchOptimizer,
    grid_search_tunables: TunableGroups,
    grid_search_tunables_grid: list[dict[str, TunableValue]],
) -> None:
    """Make sure that the grid search optimizer suggests and registers configs correctly."""
    score: dict[str, TunableValue] = {"score": 1.0, "other_score": 2.0}
    status = Status.SUCCEEDED
    suggestion = grid_search_opt.suggest()
    suggestion_dict = suggestion.get_param_values()
    default_config = grid_search_tunables.restore_defaults().get_param_values()

    # First suggestion should be the defaults.
    assert suggestion.get_param_values() == default_config
    # But that shouldn't be the first element in the grid search.
    assert suggestion_dict != next(iter(grid_search_tunables_grid))
    # The suggestion should no longer be in the pending_configs.
    assert suggestion_dict not in grid_search_opt.pending_configs
    # But it should be in the suggested_configs now (and the only one).
    assert list(grid_search_opt.suggested_configs) == [default_config]

    # Register a score for that suggestion.
    grid_search_opt.register(suggestion, status, score)
    # Now it shouldn't be in the suggested_configs.
    assert len(list(grid_search_opt.suggested_configs)) == 0

    grid_search_tunables_grid.remove(default_config)
    assert default_config not in grid_search_opt.pending_configs
    assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs)
    assert all(
        config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid
    )

    # The next suggestion should be a different element in the grid search.
    suggestion = grid_search_opt.suggest()
    suggestion_dict = suggestion.get_param_values()
    assert suggestion_dict != default_config
    assert suggestion_dict not in grid_search_opt.pending_configs
    assert suggestion_dict in grid_search_opt.suggested_configs
    grid_search_opt.register(suggestion, status, score)
    assert suggestion_dict not in grid_search_opt.pending_configs
    assert suggestion_dict not in grid_search_opt.suggested_configs

    grid_search_tunables_grid.remove(suggestion.get_param_values())
    assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs)
    assert all(
        config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid
    )

    # Note: the optimizer counts as "converged" once it has either used up its
    # "max_suggestions" budget or (temporarily) emptied the grid; suggest() can
    # refill an exhausted grid while iterations remain.

    # Try to empty the rest of the grid.
    while grid_search_opt.not_converged():
        suggestion = grid_search_opt.suggest()
        grid_search_opt.register(suggestion, status, score)

    # The grid search should be empty now.
    assert not list(grid_search_opt.pending_configs)
    assert not list(grid_search_opt.suggested_configs)
    assert not grid_search_opt.not_converged()

    # But if we still have iterations left, we should be able to suggest again by
    # refilling the grid.
    assert grid_search_opt.current_iteration < grid_search_opt.max_suggestions
    assert grid_search_opt.suggest()
    assert list(grid_search_opt.pending_configs)
    assert list(grid_search_opt.suggested_configs)
    assert grid_search_opt.not_converged()

    # Try to finish the rest of our iterations by repeating the grid.
    while grid_search_opt.not_converged():
        suggestion = grid_search_opt.suggest()
        grid_search_opt.register(suggestion, status, score)
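    # Since max_suggestions is deliberately not a multiple of the grid size, the
    # second pass stops with some refilled grid points still pending.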

    assert not grid_search_opt.not_converged()
    assert grid_search_opt.current_iteration >= grid_search_opt.max_suggestions
    assert list(grid_search_opt.pending_configs)
    assert list(grid_search_opt.suggested_configs)


def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None:
    """Make sure that the grid search optimizer works correctly when suggest and
    register are called out of order.
    """
    # pylint: disable=too-many-locals
    score: dict[str, TunableValue] = {"score": 1.0, "other_score": 2.0}
    status = Status.SUCCEEDED
    suggest_count = 10
    suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
    suggested_shuffled = suggested.copy()
    # Try to ensure the shuffled list is different.
    for _ in range(3):
        random.shuffle(suggested_shuffled)
        if suggested_shuffled != suggested:
            break
    assert suggested != suggested_shuffled

    for suggestion in suggested_shuffled:
        suggestion_dict = suggestion.get_param_values()
        assert suggestion_dict not in grid_search_opt.pending_configs
        assert suggestion_dict in grid_search_opt.suggested_configs
        grid_search_opt.register(suggestion, status, score)
        assert suggestion_dict not in grid_search_opt.suggested_configs

    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == score

    # Test re-registering the first suggestion with a better score.
    best_suggestion = suggested_shuffled[0]
    best_suggestion_dict = best_suggestion.get_param_values()
    assert best_suggestion_dict not in grid_search_opt.pending_configs
    assert best_suggestion_dict not in grid_search_opt.suggested_configs

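    # Construct a score that is strictly better on every optimization target:
    # lower for "min" targets, higher for "max" targets.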

    best_suggestion_score: dict[str, TunableValue] = {}
    for opt_target, opt_dir in grid_search_opt.targets.items():
        val = score[opt_target]
        assert isinstance(val, (int, float))
        best_suggestion_score[opt_target] = val - 1 if opt_dir == "min" else val + 1

    grid_search_opt.register(best_suggestion, status, best_suggestion_score)
    assert best_suggestion_dict not in grid_search_opt.suggested_configs

    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == best_suggestion_score
    assert best_config == best_suggestion

    # Check bulk register
    suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
    assert all(
        suggestion.get_param_values() not in grid_search_opt.pending_configs
        for suggestion in suggested
    )
    assert all(
        suggestion.get_param_values() in grid_search_opt.suggested_configs
        for suggestion in suggested
    )

    # Those new suggestions also shouldn't be in the set of previously suggested configs.
    assert all(suggestion.get_param_values() not in suggested_shuffled for suggestion in suggested)

    grid_search_opt.bulk_register(
        [suggestion.get_param_values() for suggestion in suggested],
        [score] * len(suggested),
        [status] * len(suggested),
    )

    assert all(
        suggestion.get_param_values() not in grid_search_opt.pending_configs
        for suggestion in suggested
    )
    assert all(
        suggestion.get_param_values() not in grid_search_opt.suggested_configs
        for suggestion in suggested
    )

    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == best_suggestion_score
    assert best_config == best_suggestion


def test_grid_search_register(
    grid_search_opt: GridSearchOptimizer,
    grid_search_tunables: TunableGroups,
) -> None:
    """Make sure that the `.register()` method adjusts the score signs correctly."""
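    # The optimizer minimizes internally, so scores for "max" targets should come
    # back sign-flipped, while "min" targets pass through unchanged.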

    assert grid_search_opt.register(
        grid_search_tunables,
        Status.SUCCEEDED,
        {
            "score": 1.0,
            "other_score": 2.0,
        },
    ) == {
        "score": -1.0,  # max
        "other_score": 2.0,  # min
    }

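    # A failed trial has no usable score, so every target should register as +inf
    # (the worst possible value once signs are adjusted for minimization).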

    assert grid_search_opt.register(grid_search_tunables, Status.FAILED) == {
        "score": float("inf"),
        "other_score": float("inf"),
    }