Coverage for mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py: 100%
128 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for grid search mlos_bench optimizer."""
7import itertools
8import math
9import random
10from typing import Dict, List
12import numpy as np
13import pytest
15from mlos_bench.environments.status import Status
16from mlos_bench.optimizers.grid_search_optimizer import GridSearchOptimizer
17from mlos_bench.tunables.tunable import TunableValue
18from mlos_bench.tunables.tunable_groups import TunableGroups
20# pylint: disable=redefined-outer-name
@pytest.fixture
def grid_search_tunables_config() -> dict:
    """Provide the tunables config (one group, three parameters) for the grid search tests."""
    categorical_param = {
        "type": "categorical",
        "values": ["a", "b", "c"],
        "default": "a",
    }
    int_param = {
        "type": "int",
        "range": [1, 3],
        "default": 2,
    }
    float_param = {
        "type": "float",
        "range": [0, 1],
        "default": 0.5,
        "quantization_bins": 5,
    }
    # A single covariant group named "grid"; parameters are listed as cat, int, float.
    return {
        "grid": {
            "cost": 1,
            "params": {
                "cat": categorical_param,
                "int": int_param,
                "float": float_param,
            },
        },
    }
@pytest.fixture
def grid_search_tunables_grid(
    grid_search_tunables: TunableGroups,
) -> List[Dict[str, TunableValue]]:
    """
    Build the expected grid directly from the tunable groups.

    The result is the Cartesian product of the value sets of every tunable that
    exposes explicit values. Tests use it to check that the optimizer's grid has
    the same contents (ignoring order).
    """
    names = []
    value_axes = []
    for tunable, _group in grid_search_tunables:
        values = tunable.values
        if values is None:
            # Tunables without an explicit set of values do not contribute an axis.
            continue
        names.append(tunable.name)
        value_axes.append(values)
    return [dict(zip(names, combo)) for combo in itertools.product(*value_axes)]
@pytest.fixture
def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups:
    """Build the TunableGroups under test from the grid search tunables config."""
    tunables = TunableGroups(grid_search_tunables_config)
    return tunables
@pytest.fixture
def grid_search_opt(
    grid_search_tunables: TunableGroups,
    grid_search_tunables_grid: List[Dict[str, TunableValue]],
) -> GridSearchOptimizer:
    """Create the grid search optimizer under test with two optimization targets."""
    assert len(grid_search_tunables) == 3
    grid_size = len(grid_search_tunables_grid)
    opt_config = {
        # Deliberately NOT a multiple of the grid size, so that the convergence
        # logic is exercised when the iterations run out part-way through a grid.
        "max_suggestions": grid_size * 2 - 3,
        "optimization_targets": {"score": "max", "other_score": "min"},
    }
    return GridSearchOptimizer(tunables=grid_search_tunables, config=opt_config)
def test_grid_search_grid(
    grid_search_opt: GridSearchOptimizer,
    grid_search_tunables: TunableGroups,
    grid_search_tunables_grid: List[Dict[str, TunableValue]],
) -> None:
    """Check that the optimizer's initial grid matches the one derived from the tunables."""
    # The grid size is the product of the cardinalities of all tunables
    # (a tunable without a cardinality makes the expected size infinite).
    cardinalities = [tunable.cardinality or np.inf for tunable, _group in grid_search_tunables]
    expected_grid_size = math.prod(cardinalities)
    assert expected_grid_size > len(grid_search_tunables)
    assert len(grid_search_tunables_grid) == expected_grid_size

    # Spot-check that one specific config is present in both grids.
    sample_config: Dict[str, TunableValue] = {"cat": "a", "int": 2, "float": 0.75}
    pending = list(grid_search_opt.pending_configs)
    assert sample_config in grid_search_tunables_grid
    assert sample_config in pending

    # Check the rest of the contents.
    # Note: ConfigSpace parameter name order and TunableGroups parameter name order
    # are not consistent, so we need to do a full dict comparison in both directions.
    assert len(pending) == expected_grid_size
    for config in pending:
        assert config in grid_search_tunables_grid
    for config in grid_search_tunables_grid:
        assert config in pending
    # The order of the configs is less relevant to us, so only the contents are
    # compared above rather than asserting that the two sequences are equal.
def test_grid_search(
    grid_search_opt: GridSearchOptimizer,
    grid_search_tunables: TunableGroups,
    grid_search_tunables_grid: List[Dict[str, TunableValue]],
) -> None:
    """Walk the optimizer through a full grid and then through a refill of the grid."""
    score: Dict[str, TunableValue] = {"score": 1.0, "other_score": 2.0}
    status = Status.SUCCEEDED

    def assert_pending_matches_grid() -> None:
        """Check that the pending configs and the expected grid hold the same configs."""
        for config in grid_search_opt.pending_configs:
            assert config in grid_search_tunables_grid
        pending = list(grid_search_opt.pending_configs)
        for config in grid_search_tunables_grid:
            assert config in pending

    first = grid_search_opt.suggest()
    first_values = first.get_param_values()
    default_config = grid_search_tunables.restore_defaults().get_param_values()

    # The very first suggestion should be the default config ...
    assert first.get_param_values() == default_config
    # ... which should not be the first element of the expected grid.
    assert first_values != next(iter(grid_search_tunables_grid))
    # Once suggested, a config moves out of pending_configs ...
    assert first_values not in grid_search_opt.pending_configs
    # ... and becomes the one and only entry in suggested_configs.
    assert list(grid_search_opt.suggested_configs) == [default_config]

    # Registering a score for the suggestion removes it from suggested_configs.
    grid_search_opt.register(first, status, score)
    assert len(list(grid_search_opt.suggested_configs)) == 0

    # Mirror the removal in the expected grid and compare the remaining contents.
    grid_search_tunables_grid.remove(default_config)
    assert default_config not in grid_search_opt.pending_configs
    assert_pending_matches_grid()

    # The next suggestion should be a different element of the grid.
    second = grid_search_opt.suggest()
    second_values = second.get_param_values()
    assert second_values != default_config
    assert second_values not in grid_search_opt.pending_configs
    assert second_values in grid_search_opt.suggested_configs
    grid_search_opt.register(second, status, score)
    assert second_values not in grid_search_opt.pending_configs
    assert second_values not in grid_search_opt.suggested_configs

    grid_search_tunables_grid.remove(second.get_param_values())
    assert_pending_matches_grid()

    # The optimizer reports convergence either when the grid has been emptied
    # or when "max_suggestions" has been reached; both cases are exercised below.

    # Drain the rest of the grid.
    while grid_search_opt.not_converged():
        grid_search_opt.register(grid_search_opt.suggest(), status, score)

    # Nothing should be pending or outstanding now.
    assert not list(grid_search_opt.pending_configs)
    assert not list(grid_search_opt.suggested_configs)
    assert not grid_search_opt.not_converged()

    # There are still iterations left, so asking for another suggestion
    # should refill the grid and make the optimizer "not converged" again.
    assert grid_search_opt.current_iteration < grid_search_opt.max_suggestions
    assert grid_search_opt.suggest()
    assert list(grid_search_opt.pending_configs)
    assert list(grid_search_opt.suggested_configs)
    assert grid_search_opt.not_converged()

    # Use up the remaining iterations by repeating the grid.
    while grid_search_opt.not_converged():
        grid_search_opt.register(grid_search_opt.suggest(), status, score)
    assert not grid_search_opt.not_converged()
    assert grid_search_opt.current_iteration >= grid_search_opt.max_suggestions
    assert list(grid_search_opt.pending_configs)
    assert list(grid_search_opt.suggested_configs)
def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None:
    """
    Make sure that grid search optimizer works correctly when suggest and register
    are called out of order.

    The test covers three scenarios:

    1. A batch of suggestions is registered in a shuffled order.
    2. One already registered suggestion is re-registered with a better score and
       must become the best observation.
    3. A second batch of suggestions is registered in bulk and must not displace
       the best observation.
    """
    # pylint: disable=too-many-locals
    score: Dict[str, TunableValue] = {"score": 1.0, "other_score": 2.0}
    status = Status.SUCCEEDED
    suggest_count = 10
    suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
    suggested_shuffled = suggested.copy()
    # Try to ensure the shuffled list is different.
    # (A shuffle can return the original order, so retry a few times.)
    for _ in range(3):
        random.shuffle(suggested_shuffled)
        if suggested_shuffled != suggested:
            break
    assert suggested != suggested_shuffled

    # Register the suggestions in a different order than they were suggested in.
    for suggestion in suggested_shuffled:
        suggestion_dict = suggestion.get_param_values()
        assert suggestion_dict not in grid_search_opt.pending_configs
        assert suggestion_dict in grid_search_opt.suggested_configs
        grid_search_opt.register(suggestion, status, score)
        assert suggestion_dict not in grid_search_opt.suggested_configs

    # All suggestions were registered with the same score, so that is the best one.
    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == score

    # Test re-register with a better score.
    best_suggestion = suggested_shuffled[0]
    best_suggestion_dict = best_suggestion.get_param_values()
    assert best_suggestion_dict not in grid_search_opt.pending_configs
    assert best_suggestion_dict not in grid_search_opt.suggested_configs

    # "Better" depends on the direction of each target:
    # lower for "min" targets, higher for everything else.
    best_suggestion_score: Dict[str, TunableValue] = {}
    for opt_target, opt_dir in grid_search_opt.targets.items():
        val = score[opt_target]
        assert isinstance(val, (int, float))
        best_suggestion_score[opt_target] = val - 1 if opt_dir == "min" else val + 1

    grid_search_opt.register(best_suggestion, status, best_suggestion_score)
    assert best_suggestion_dict not in grid_search_opt.suggested_configs

    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == best_suggestion_score
    assert best_config == best_suggestion

    # Check bulk register.
    suggested = [grid_search_opt.suggest() for _ in range(suggest_count)]
    assert all(
        suggestion.get_param_values() not in grid_search_opt.pending_configs
        for suggestion in suggested
    )
    assert all(
        suggestion.get_param_values() in grid_search_opt.suggested_configs
        for suggestion in suggested
    )

    # Those new suggestions also shouldn't be in the set of previously suggested configs.
    # Note: compare parameter-value dicts with parameter-value dicts. Checking a dict
    # for membership in the list of suggestion objects mixes two different types, so
    # it could never detect a repeated config.
    previously_suggested = [suggestion.get_param_values() for suggestion in suggested_shuffled]
    assert all(
        suggestion.get_param_values() not in previously_suggested for suggestion in suggested
    )

    grid_search_opt.bulk_register(
        [suggestion.get_param_values() for suggestion in suggested],
        [score] * len(suggested),
        [status] * len(suggested),
    )

    assert all(
        suggestion.get_param_values() not in grid_search_opt.pending_configs
        for suggestion in suggested
    )
    assert all(
        suggestion.get_param_values() not in grid_search_opt.suggested_configs
        for suggestion in suggested
    )

    # The bulk-registered scores are not better than the re-registered one,
    # so the best observation must be unchanged.
    best_score, best_config = grid_search_opt.get_best_observation()
    assert best_score == best_suggestion_score
    assert best_config == best_suggestion
def test_grid_search_register(
    grid_search_opt: GridSearchOptimizer,
    grid_search_tunables: TunableGroups,
) -> None:
    """Make sure that the `.register()` method adjusts the score signs correctly."""
    # On success, the "max" target comes back negated and the "min" target unchanged.
    reported_scores = {"score": 1.0, "other_score": 2.0}
    registered_ok = grid_search_opt.register(
        grid_search_tunables,
        Status.SUCCEEDED,
        reported_scores,
    )
    assert registered_ok == {
        "score": -1.0,  # max
        "other_score": 2.0,  # min
    }

    # A failed trial registered without scores reports +inf for every target.
    registered_failed = grid_search_opt.register(grid_search_tunables, Status.FAILED)
    assert registered_failed == {
        "score": float("inf"),
        "other_score": float("inf"),
    }