Coverage for mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py: 100%
127 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for grid search mlos_bench optimizer."""
7import itertools
8import math
9import random
11import numpy as np
12import pytest
14from mlos_bench.environments.status import Status
15from mlos_bench.optimizers.grid_search_optimizer import GridSearchOptimizer
16from mlos_bench.tunables.tunable import TunableValue
17from mlos_bench.tunables.tunable_groups import TunableGroups
19# pylint: disable=redefined-outer-name
@pytest.fixture
def grid_search_tunables_config() -> dict:
    """Test fixture for grid search optimizer tunables config."""
    # One covariant group ("grid") with three tunables: a categorical, a small
    # int range, and a quantized float, so the full grid stays tiny.
    params = {
        "cat": {
            "type": "categorical",
            "values": ["a", "b", "c"],
            "default": "a",
        },
        "int": {
            "type": "int",
            "range": [1, 3],
            "default": 2,
        },
        "float": {
            "type": "float",
            "range": [0, 1],
            "default": 0.5,
            "quantization_bins": 5,
        },
    }
    return {
        "grid": {
            "cost": 1,
            "params": params,
        },
    }
@pytest.fixture
def grid_search_tunables_grid(
    grid_search_tunables: TunableGroups,
) -> list[dict[str, TunableValue]]:
    """
    Test fixture for grid from tunable groups.

    Used to check that the grids are the same (ignoring order).
    """
    # Collect the name and the enumerable values of each tunable that has a
    # finite value set, then take the cross product of those value lists.
    names = []
    value_lists = []
    for tunable, _group in grid_search_tunables:
        if tunable.values is None:
            continue
        names.append(tunable.name)
        value_lists.append(tunable.values)
    return [dict(zip(names, combo)) for combo in itertools.product(*value_lists)]
@pytest.fixture
def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups:
    """Test fixture for grid search optimizer tunables."""
    tunable_groups = TunableGroups(grid_search_tunables_config)
    return tunable_groups
@pytest.fixture
def grid_search_opt(
    grid_search_tunables: TunableGroups,
    grid_search_tunables_grid: list[dict[str, TunableValue]],
) -> GridSearchOptimizer:
    """Test fixture for grid search optimizer."""
    assert len(grid_search_tunables) == 3
    # Deliberately pick an iteration budget that is NOT a multiple of the grid
    # size, so the convergence logic gets exercised partway through a repeat pass.
    iteration_budget = 2 * len(grid_search_tunables_grid) - 3
    opt_config = {
        "max_suggestions": iteration_budget,
        "optimization_targets": {"score": "max", "other_score": "min"},
    }
    return GridSearchOptimizer(tunables=grid_search_tunables, config=opt_config)
def test_grid_search_grid(
    grid_search_opt: GridSearchOptimizer,
    grid_search_tunables: TunableGroups,
    grid_search_tunables_grid: list[dict[str, TunableValue]],
) -> None:
    """Make sure that grid search optimizer initializes and works correctly."""
    # The grid size should be the product of the tunables' cardinalities
    # (np.inf as a sentinel would immediately fail the equality check below).
    expected_grid_size = math.prod(
        tunable.cardinality or np.inf for tunable, _group in grid_search_tunables
    )
    assert len(grid_search_tunables) < expected_grid_size
    assert expected_grid_size == len(grid_search_tunables_grid)
    # Spot-check a known config for inclusion in both grids.
    sample_config: dict[str, TunableValue] = {
        "cat": "a",
        "int": 2,
        "float": 0.75,
    }
    pending = list(grid_search_opt.pending_configs)
    assert sample_config in grid_search_tunables_grid
    assert sample_config in pending
    # Check the rest of the contents.
    # Note: ConfigSpace param name vs TunableGroup parameter name order is not
    # consistent, so we need to full dict comparison.
    assert len(pending) == expected_grid_size
    assert all(config in grid_search_tunables_grid for config in pending)
    assert all(config in pending for config in grid_search_tunables_grid)
    # Order is less relevant to us, so we'll just check that the sets are the same.
    # assert grid_search_opt.pending_configs == grid_search_tunables_grid
def test_grid_search(
    grid_search_opt: GridSearchOptimizer,
    grid_search_tunables: TunableGroups,
    grid_search_tunables_grid: list[dict[str, TunableValue]],
) -> None:
    """
    Walk the grid search optimizer through a full suggest/register cycle.

    Checks that configs move from pending to suggested and back out on register,
    that the grid refills once exhausted, and that convergence is driven by
    ``max_suggestions``.
    """
    score: dict[str, TunableValue] = {"score": 1.0, "other_score": 2.0}
    status = Status.SUCCEEDED
    suggestion = grid_search_opt.suggest()
    suggestion_dict = suggestion.get_param_values()
    default_config = grid_search_tunables.restore_defaults().get_param_values()
    # First suggestion should be the defaults.
    assert suggestion.get_param_values() == default_config
    # But that shouldn't be the first element in the grid search.
    assert suggestion_dict != next(iter(grid_search_tunables_grid))
    # The suggestion should no longer be in the pending_configs.
    assert suggestion_dict not in grid_search_opt.pending_configs
    # But it should be in the suggested_configs now (and the only one).
    assert list(grid_search_opt.suggested_configs) == [default_config]
    # Register a score for that suggestion.
    grid_search_opt.register(suggestion, status, score)
    # Now it shouldn't be in the suggested_configs.
    assert len(list(grid_search_opt.suggested_configs)) == 0
    # Mirror the removal in our local copy of the grid, then check that the
    # optimizer's remaining pending configs match it exactly (ignoring order).
    grid_search_tunables_grid.remove(default_config)
    assert default_config not in grid_search_opt.pending_configs
    assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs)
    assert all(
        config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid
    )
    # The next suggestion should be a different element in the grid search.
    suggestion = grid_search_opt.suggest()
    suggestion_dict = suggestion.get_param_values()
    assert suggestion_dict != default_config
    assert suggestion_dict not in grid_search_opt.pending_configs
    assert suggestion_dict in grid_search_opt.suggested_configs
    grid_search_opt.register(suggestion, status, score)
    assert suggestion_dict not in grid_search_opt.pending_configs
    assert suggestion_dict not in grid_search_opt.suggested_configs
    # Again, mirror the removal locally and re-check set equality of the grids.
    grid_search_tunables_grid.remove(suggestion.get_param_values())
    assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs)
    assert all(
        config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid
    )
    # We consider not_converged as either having reached "max_suggestions" or an empty grid?
    # Try to empty the rest of the grid.
    while grid_search_opt.not_converged():
        suggestion = grid_search_opt.suggest()
        grid_search_opt.register(suggestion, status, score)
    # The grid search should be empty now.
    assert not list(grid_search_opt.pending_configs)
    assert not list(grid_search_opt.suggested_configs)
    assert not grid_search_opt.not_converged()
    # But if we still have iterations left, we should be able to suggest again by
    # refilling the grid.
    assert grid_search_opt.current_iteration < grid_search_opt.max_suggestions
    assert grid_search_opt.suggest()
    assert list(grid_search_opt.pending_configs)
    assert list(grid_search_opt.suggested_configs)
    assert grid_search_opt.not_converged()
    # Try to finish the rest of our iterations by repeating the grid.
    while grid_search_opt.not_converged():
        suggestion = grid_search_opt.suggest()
        grid_search_opt.register(suggestion, status, score)
    # max_suggestions is not a multiple of the grid size (see the fixture), so
    # we stop mid-grid: converged, budget exhausted, grid still partially full.
    assert not grid_search_opt.not_converged()
    assert grid_search_opt.current_iteration >= grid_search_opt.max_suggestions
    assert list(grid_search_opt.pending_configs)
    assert list(grid_search_opt.suggested_configs)
def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None:
    """Make sure that grid search optimizer works correctly when suggest and register
    are called out of order.
    """
    # pylint: disable=too-many-locals
    score: dict[str, TunableValue] = {"score": 1.0, "other_score": 2.0}
    status = Status.SUCCEEDED
    suggest_count = 10
    # Take several suggestions up front, then register them in a different order.
    suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
    suggested_shuffled = suggested.copy()
    # Try to ensure the shuffled list is different.
    # (With 10 elements, three shuffles all matching the original is vanishingly
    # unlikely, but not impossible — hence the retry loop.)
    for _ in range(3):
        random.shuffle(suggested_shuffled)
        if suggested_shuffled != suggested:
            break
    assert suggested != suggested_shuffled
    # Register out of order: each config should leave suggested_configs as soon
    # as its result is registered, regardless of the original suggest order.
    for suggestion in suggested_shuffled:
        suggestion_dict = suggestion.get_param_values()
        assert suggestion_dict not in grid_search_opt.pending_configs
        assert suggestion_dict in grid_search_opt.suggested_configs
        grid_search_opt.register(suggestion, status, score)
        assert suggestion_dict not in grid_search_opt.suggested_configs
    best_score, best_config = grid_search_opt.get_best_observation()
    # All registrations used the same score, so it must be the best one.
    assert best_score == score
    # test re-register with higher score
    best_suggestion = suggested_shuffled[0]
    best_suggestion_dict = best_suggestion.get_param_values()
    assert best_suggestion_dict not in grid_search_opt.pending_configs
    assert best_suggestion_dict not in grid_search_opt.suggested_configs
    # Build a strictly-better score for every target: lower for "min" targets,
    # higher for "max" targets.
    best_suggestion_score: dict[str, TunableValue] = {}
    for opt_target, opt_dir in grid_search_opt.targets.items():
        val = score[opt_target]
        assert isinstance(val, (int, float))
        best_suggestion_score[opt_target] = val - 1 if opt_dir == "min" else val + 1
    grid_search_opt.register(best_suggestion, status, best_suggestion_score)
    assert best_suggestion_dict not in grid_search_opt.suggested_configs
    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == best_suggestion_score
    assert best_config == best_suggestion
    # Check bulk register
    suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
    assert all(
        suggestion.get_param_values() not in grid_search_opt.pending_configs
        for suggestion in suggested
    )
    assert all(
        suggestion.get_param_values() in grid_search_opt.suggested_configs
        for suggestion in suggested
    )
    # Those new suggestions also shouldn't be in the set of previously suggested configs.
    # NOTE(review): this compares param-value dicts against the suggestion
    # objects in suggested_shuffled, not against their get_param_values() —
    # presumably relies on cross-type equality; verify this checks what's intended.
    assert all(suggestion.get_param_values() not in suggested_shuffled for suggestion in suggested)
    grid_search_opt.bulk_register(
        [suggestion.get_param_values() for suggestion in suggested],
        [score] * len(suggested),
        [status] * len(suggested),
    )
    # After bulk_register, all of the new suggestions are fully accounted for.
    assert all(
        suggestion.get_param_values() not in grid_search_opt.pending_configs
        for suggestion in suggested
    )
    assert all(
        suggestion.get_param_values() not in grid_search_opt.suggested_configs
        for suggestion in suggested
    )
    # The bulk-registered (plain) score shouldn't displace the earlier best one.
    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == best_suggestion_score
    assert best_config == best_suggestion
def test_grid_search_register(
    grid_search_opt: GridSearchOptimizer,
    grid_search_tunables: TunableGroups,
) -> None:
    """Make sure that the `.register()` method adjusts the score signs correctly."""
    raw_score: dict[str, TunableValue] = {
        "score": 1.0,
        "other_score": 2.0,
    }
    # "score" is a "max" target, so its sign gets flipped internally;
    # "other_score" is a "min" target and passes through unchanged.
    expected_adjusted = {
        "score": -1.0,  # max
        "other_score": 2.0,  # min
    }
    adjusted = grid_search_opt.register(grid_search_tunables, Status.SUCCEEDED, raw_score)
    assert adjusted == expected_adjusted
    # A failed run registers +inf for every optimization target.
    failed_result = grid_search_opt.register(grid_search_tunables, Status.FAILED)
    assert failed_result == {target: float("inf") for target in ("score", "other_score")}