Coverage for mlos_bench/mlos_bench/schedulers/base_scheduler.py: 91%

130 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Base class for the optimization loop scheduling policies.""" 

6 

7import json 

8import logging 

9from abc import ABCMeta, abstractmethod 

10from datetime import datetime 

11from types import TracebackType 

12from typing import Any, Dict, List, Optional, Tuple, Type 

13 

14from pytz import UTC 

15from typing_extensions import Literal 

16 

17from mlos_bench.config.schemas import ConfigSchema 

18from mlos_bench.environments.base_environment import Environment 

19from mlos_bench.optimizers.base_optimizer import Optimizer 

20from mlos_bench.storage.base_storage import Storage 

21from mlos_bench.tunables.tunable_groups import TunableGroups 

22from mlos_bench.util import merge_parameters 

23 

24_LOG = logging.getLogger(__name__) 

25 

26 

class Scheduler(metaclass=ABCMeta):
    # pylint: disable=too-many-instance-attributes
    """Base class for the optimization loop scheduling policies."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        config: Dict[str, Any],
        global_config: Dict[str, Any],
        environment: Environment,
        optimizer: Optimizer,
        storage: Storage,
        root_env_config: str,
    ):
        """
        Create a new instance of the scheduler. The constructor of this and the derived
        classes is called by the persistence service after reading the class JSON
        configuration. Other objects like the Environment and Optimizer are provided by
        the Launcher.

        Parameters
        ----------
        config : dict
            The configuration for the scheduler.
        global_config : dict
            The global configuration for the experiment.
        environment : Environment
            The environment to benchmark/optimize.
        optimizer : Optimizer
            The optimizer to use.
        storage : Storage
            The storage to use.
        root_env_config : str
            Path to the root environment configuration.
        """
        self.global_config = global_config
        # Merge global params into a *copy* of the local config so the caller's
        # dict is not mutated; experiment_id and trial_id must be present.
        config = merge_parameters(
            dest=config.copy(),
            source=global_config,
            required_keys=["experiment_id", "trial_id"],
        )
        self._validate_json_config(config)

        self._experiment_id = config["experiment_id"].strip()
        self._trial_id = int(config["trial_id"])
        # -1 is the "not set" sentinel for both config_id and max_trials.
        self._config_id = int(config.get("config_id", -1))
        self._max_trials = int(config.get("max_trials", -1))
        self._trial_count = 0

        self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
        if self._trial_config_repeat_count <= 0:
            raise ValueError(
                f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}"
            )

        self._do_teardown = bool(config.get("teardown", True))

        # The Experiment is opened lazily in __enter__ and closed in __exit__.
        self.experiment: Optional[Storage.Experiment] = None
        self.environment = environment
        self.optimizer = optimizer
        self.storage = storage
        self._root_env_config = root_env_config
        self._last_trial_id = -1
        self._ran_trials: List[Storage.Trial] = []

        _LOG.debug("Scheduler instantiated: %s :: %s", self, config)

    def _validate_json_config(self, config: dict) -> None:
        """Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading
        mechanism.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config.copy()
            # The json schema does not allow for -1 as a valid value for config_id.
            # As it is just a default placeholder value, and not required, we can
            # remove it from the config copy prior to validation safely.
            config_id = json_config["config"].get("config_id")
            if config_id is not None and isinstance(config_id, int) and config_id < 0:
                json_config["config"].pop("config_id")
        ConfigSchema.SCHEDULER.validate(json_config)

    @property
    def trial_config_repeat_count(self) -> int:
        """Gets the number of trials to run for a given config."""
        return self._trial_config_repeat_count

    @property
    def trial_count(self) -> int:
        """Gets the current number of trials run for the experiment."""
        return self._trial_count

    @property
    def max_trials(self) -> int:
        """Gets the maximum number of trials to run for a given experiment, or -1 for no
        limit.
        """
        return self._max_trials

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the Scheduler (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the Scheduler.
        """
        return self.__class__.__name__

    def __enter__(self) -> "Scheduler":
        """Enter the scheduler's context."""
        _LOG.debug("Scheduler START :: %s", self)
        assert self.experiment is None
        self.environment.__enter__()
        self.optimizer.__enter__()
        # Start new or resume the existing experiment. Verify that the
        # experiment configuration is compatible with the previous runs.
        # If the `merge` config parameter is present, merge in the data
        # from other experiments and check for compatibility.
        self.experiment = self.storage.experiment(
            experiment_id=self._experiment_id,
            trial_id=self._trial_id,
            root_env_config=self._root_env_config,
            description=self.environment.name,
            tunables=self.environment.tunable_params,
            opt_targets=self.optimizer.targets,
        ).__enter__()
        return self

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """Exit the context of the scheduler."""
        if ex_val is None:
            _LOG.debug("Scheduler END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Scheduler END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self.experiment is not None
        # Close nested contexts in reverse order of entry.
        self.experiment.__exit__(ex_type, ex_val, ex_tb)
        self.optimizer.__exit__(ex_type, ex_val, ex_tb)
        self.environment.__exit__(ex_type, ex_val, ex_tb)
        self.experiment = None
        return False  # Do not suppress exceptions

    @abstractmethod
    def start(self) -> None:
        """Start the optimization loop."""
        assert self.experiment is not None
        _LOG.info(
            "START: Experiment: %s Env: %s Optimizer: %s",
            self.experiment,
            self.environment,
            self.optimizer,
        )
        if _LOG.isEnabledFor(logging.INFO):
            _LOG.info("Root Environment:\n%s", self.environment.pprint())

        # If an explicit config_id was given (> 0), schedule that stored config first.
        if self._config_id > 0:
            tunables = self.load_config(self._config_id)
            self.schedule_trial(tunables)

    def teardown(self) -> None:
        """
        Tear down the environment.

        Call it after the completion of the `.start()` in the scheduler context.
        """
        assert self.experiment is not None
        if self._do_teardown:
            self.environment.teardown()

    def get_best_observation(self) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
        """Get the best observation from the optimizer."""
        (best_score, best_config) = self.optimizer.get_best_observation()
        _LOG.info("Env: %s best score: %s", self.environment, best_score)
        return (best_score, best_config)

    def load_config(self, config_id: int) -> TunableGroups:
        """Load the existing tunable configuration from the storage."""
        assert self.experiment is not None
        tunable_values = self.experiment.load_tunable_config(config_id)
        tunables = self.environment.tunable_params.assign(tunable_values)
        _LOG.info("Load config from storage: %d", config_id)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2))
        return tunables

    def _schedule_new_optimizer_suggestions(self) -> bool:
        """
        Optimizer part of the loop.

        Load the results of the executed trials into the optimizer, suggest new
        configurations, and add them to the queue. Return True if optimization is not
        over, False otherwise.
        """
        assert self.experiment is not None
        # Only load trials newer than the last one already registered.
        (trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
        _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids)
        self.optimizer.bulk_register(configs, scores, status)
        self._last_trial_id = max(trial_ids, default=self._last_trial_id)

        not_done = self.not_done()
        if not_done:
            tunables = self.optimizer.suggest()
            self.schedule_trial(tunables)

        return not_done

    def schedule_trial(self, tunables: TunableGroups) -> None:
        """Add a configuration to the queue of trials."""
        # Schedule the same config trial_config_repeat_count times (repeat_i is 1-based).
        for repeat_i in range(1, self._trial_config_repeat_count + 1):
            self._add_trial_to_queue(
                tunables,
                config={
                    # Add some additional metadata to track for the trial such as the
                    # optimizer config used.
                    # Note: these values are unfortunately mutable at the moment.
                    # Consider them as hints of what the config was the trial *started*.
                    # It is possible that the experiment configs were changed
                    # between resuming the experiment (since that is not currently
                    # prevented).
                    "optimizer": self.optimizer.name,
                    "repeat_i": repeat_i,
                    "is_defaults": tunables.is_defaults(),
                    **{
                        f"opt_{key}_{i}": val
                        for (i, opt_target) in enumerate(self.optimizer.targets.items())
                        for (key, val) in zip(["target", "direction"], opt_target)
                    },
                },
            )

    def _add_trial_to_queue(
        self,
        tunables: TunableGroups,
        ts_start: Optional[datetime] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Add a configuration to the queue of trials.

        A wrapper for the `Experiment.new_trial` method.
        """
        assert self.experiment is not None
        trial = self.experiment.new_trial(tunables, ts_start, config)
        _LOG.info("QUEUE: Add new trial: %s", trial)

    def _run_schedule(self, running: bool = False) -> None:
        """
        Scheduler part of the loop.

        Check for pending trials in the queue and run them.
        """
        assert self.experiment is not None
        for trial in self.experiment.pending_trials(datetime.now(UTC), running=running):
            self.run_trial(trial)

    def not_done(self) -> bool:
        """
        Check the stopping conditions.

        By default, stop when the optimizer converges or max limit of trials reached.
        """
        # max_trials <= 0 means "no limit" (see __init__).
        return self.optimizer.not_converged() and (
            self._trial_count < self._max_trials or self._max_trials <= 0
        )

    @abstractmethod
    def run_trial(self, trial: Storage.Trial) -> None:
        """
        Set up and run a single trial.

        Save the results in the storage.
        """
        assert self.experiment is not None
        self._trial_count += 1
        self._ran_trials.append(trial)
        _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial)

    @property
    def ran_trials(self) -> List[Storage.Trial]:
        """Get the list of trials that were run."""
        return self._ran_trials