Coverage for mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py: 92%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Grid search optimizer for mlos_bench.""" 

6 

7import logging 

8from typing import Dict, Iterable, Optional, Sequence, Set, Tuple 

9 

10import ConfigSpace 

11import numpy as np 

12from ConfigSpace.util import generate_grid 

13 

14from mlos_bench.environments.status import Status 

15from mlos_bench.optimizers.convert_configspace import configspace_data_to_tunable_values 

16from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer 

17from mlos_bench.services.base_service import Service 

18from mlos_bench.tunables.tunable import TunableValue 

19from mlos_bench.tunables.tunable_groups import TunableGroups 

20 

# Module-level logger, keyed by this module's import name.
_LOG: logging.Logger = logging.getLogger(name=__name__)

22 

23 

class GridSearchOptimizer(TrackBestOptimizer):
    """
    Grid search optimizer.

    Enumerates every point of the (quantized) configuration grid generated by
    ConfigSpace, optionally suggesting the default configuration first, and
    restarts the grid from the beginning if it is exhausted before
    ``max_suggestions`` iterations have been used.
    """

    def __init__(
        self,
        tunables: TunableGroups,
        config: dict,
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new grid search optimizer.

        Parameters
        ----------
        tunables : TunableGroups
            The tunables to search over. Every tunable must have a finite,
            non-zero cardinality (see ``_sanity_check``).
        config : dict
            Optimizer configuration, passed through to the base class.
        global_config : Optional[dict]
            Global parameters, passed through to the base class.
        service : Optional[Service]
            Optional service object, passed through to the base class.

        Raises
        ------
        ValueError
            If any of the tunables is not quantized (i.e., the grid is infinite).
        """
        super().__init__(tunables, config, global_config, service)

        # NOTE: Track the grid as a set of tuples of tunable values and reconstruct
        # the dicts as necessary (zipping with ``self._config_keys``).
        # This is not the most efficient way to do this, but avoids
        # introducing a new data structure for hashable dicts.
        # See https://github.com/microsoft/MLOS/pull/690 for further discussion.

        self._sanity_check()
        # The ordered set of pending configs that have not yet been suggested
        # (a dict with ``None`` values, used as an insertion-ordered set).
        self._config_keys, self._pending_configs = self._get_grid()
        assert self._pending_configs
        # A set of suggested configs that have not yet been registered.
        self._suggested_configs: Set[Tuple[TunableValue, ...]] = set()

    def _sanity_check(self) -> None:
        """
        Check that the grid is finite and warn if it is large.

        Raises
        ------
        ValueError
            If any tunable has no (or zero) cardinality, i.e., is not quantized.
        """
        # Compute the grid size with Python ints (arbitrary precision) instead of
        # ``np.prod`` over an int64 array: the latter silently wraps around on
        # overflow for large grids, which would suppress the warnings below.
        size: float = 1
        for (tunable, _group) in self._tunables:
            cardinality = tunable.cardinality
            size *= int(cardinality) if cardinality else np.inf
        if size == np.inf:
            raise ValueError(
                f"Unquantized tunables are not supported for grid search: {self._tunables}"
            )
        if size > 10000:
            _LOG.warning(
                "Large number %d of config points requested for grid search: %s",
                size,
                self._tunables,
            )
        if size > self._max_suggestions:
            _LOG.warning(
                "Grid search size %d, is greater than max iterations %d",
                size,
                self._max_suggestions,
            )

    def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]:
        """
        Gets a grid of configs to try.

        Order is given by ConfigSpace, but preserved by dict ordering semantics.

        Returns
        -------
        Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]
            The parameter names (in ConfigSpace order) and an insertion-ordered
            "set" (a dict with ``None`` values) of value tuples, one per grid
            point, whose elements line up with those names.
        """
        # Since we are using ConfigSpace to generate the grid, but only tracking the
        # values as (ordered) tuples, we also need to use its ordering on column
        # names instead of the order given by TunableGroups.
        num_steps = {
            tunable.name: tunable.cardinality or 0  # mypy wants an int
            for (tunable, _group) in self._tunables
            if tunable.is_numerical and tunable.cardinality
        }
        configs = [
            configspace_data_to_tunable_values(dict(config))
            for config in generate_grid(self.config_space, num_steps)
        ]
        names = {tuple(config.keys()) for config in configs}
        # Every grid point must share the same parameter names in the same order,
        # otherwise the value tuples would not be comparable with each other.
        assert len(names) == 1
        return names.pop(), {tuple(config.values()): None for config in configs}

    @property
    def pending_configs(self) -> Iterable[Dict[str, TunableValue]]:
        """
        Gets the set of pending configs in this grid search optimizer.

        Returns
        -------
        Iterable[Dict[str, TunableValue]]
            A lazy generator of ``{name: value}`` dicts, in grid order.
        """
        # Rebuild the dicts from the stored value tuples (see NOTE in __init__).
        return (dict(zip(self._config_keys, config)) for config in self._pending_configs)

    @property
    def suggested_configs(self) -> Iterable[Dict[str, TunableValue]]:
        """
        Gets the set of configs that have been suggested but not yet registered.

        Returns
        -------
        Iterable[Dict[str, TunableValue]]
            A lazy generator of ``{name: value}`` dicts (set order, i.e., unordered).
        """
        # Rebuild the dicts from the stored value tuples (see NOTE in __init__).
        return (dict(zip(self._config_keys, config)) for config in self._suggested_configs)

    def bulk_register(
        self,
        configs: Sequence[dict],
        scores: Sequence[Optional[Dict[str, TunableValue]]],
        status: Optional[Sequence[Status]] = None,
    ) -> bool:
        """
        Register a batch of configs with their scores and statuses.

        Parameters
        ----------
        configs : Sequence[dict]
            Tunable values, one dict per observation.
        scores : Sequence[Optional[Dict[str, TunableValue]]]
            Scores for each observation (may be None).
        status : Optional[Sequence[Status]]
            Status of each observation; defaults to ``Status.SUCCEEDED`` for all.

        Returns
        -------
        bool
            False if the base class declined the batch, True otherwise.
        """
        if not super().bulk_register(configs, scores, status):
            return False
        if status is None:
            status = [Status.SUCCEEDED] * len(configs)
        for params, score, trial_status in zip(configs, scores, status):
            tunables = self._tunables.copy().assign(params)
            self.register(tunables, trial_status, score)
        if _LOG.isEnabledFor(logging.DEBUG):
            (best_score, _) = self.get_best_observation()
            _LOG.debug("Update END: %s = %s", self, best_score)
        return True

    def suggest(self) -> TunableGroups:
        """
        Generate the next grid search suggestion.

        The first call suggests the default configuration if
        ``start_with_defaults`` is set; subsequent calls walk the pending grid in
        order, restarting it when exhausted while iterations remain.

        Raises
        ------
        ValueError
            If there are no more pending configs to suggest.
        """
        tunables = super().suggest()
        if self._start_with_defaults:
            _LOG.info("Use default values for the first trial")
            self._start_with_defaults = False
            tunables = tunables.restore_defaults()
            # Need to index based on ConfigSpace dict ordering.
            default_config = dict(self.config_space.get_default_configuration())
            assert tunables.get_param_values() == default_config
            # Move the default from the pending to the suggested set.
            default_config_values = tuple(default_config.values())
            del self._pending_configs[default_config_values]
            self._suggested_configs.add(default_config_values)
        else:
            # Select the first item from the pending configs.
            if not self._pending_configs and self._iter <= self._max_suggestions:
                _LOG.info("No more pending configs to suggest. Restarting grid.")
                self._config_keys, self._pending_configs = self._get_grid()
            try:
                next_config_values = next(iter(self._pending_configs.keys()))
            except StopIteration as exc:
                raise ValueError("No more pending configs to suggest.") from exc
            next_config = dict(zip(self._config_keys, next_config_values))
            tunables.assign(next_config)
            # Move it to the suggested set.
            self._suggested_configs.add(next_config_values)
            del self._pending_configs[next_config_values]
        _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables)
        return tunables

    def register(
        self,
        tunables: TunableGroups,
        status: Status,
        score: Optional[Dict[str, TunableValue]] = None,
    ) -> Optional[Dict[str, float]]:
        """
        Register the observation and drop its config from the suggested set.

        A config that is not in the suggested set (e.g., registered twice) only
        produces a warning.

        Returns
        -------
        Optional[Dict[str, float]]
            Whatever the base class ``register`` returned.
        """
        registered_score = super().register(tunables, status, score)
        try:
            # Round-trip through ConfigSpace so the value tuple uses the same
            # column ordering as the one produced by ``_get_grid``.
            config = dict(
                ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values())
            )
            self._suggested_configs.remove(tuple(config.values()))
        except KeyError:
            _LOG.warning(
                (
                    "Attempted to remove missing config "
                    "(previously registered?) from suggested set: %s"
                ),
                tunables,
            )
        return registered_score

    def not_converged(self) -> bool:
        """
        Return True while there are pending configs and iterations left.

        Once the iteration count exceeds ``max_suggestions`` this returns False,
        logging a warning if grid points were left unexplored.
        """
        if self._iter > self._max_suggestions:
            if self._pending_configs:
                _LOG.warning(
                    "Exceeded max iterations, but still have %d pending configs: %s",
                    len(self._pending_configs),
                    list(self._pending_configs.keys()),
                )
            return False
        return bool(self._pending_configs)
196 return bool(self._pending_configs)