Coverage for mlos_bench/mlos_bench/schedulers/base_scheduler.py: 89%

195 statements  

coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Base class for the optimization loop scheduling policies."""

import json
import logging
from abc import ABCMeta, abstractmethod
from collections.abc import Iterable
from contextlib import AbstractContextManager as ContextManager
from datetime import datetime
from types import TracebackType
from typing import Any, Literal

from pytz import UTC

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.environments.base_environment import Environment
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.schedulers.trial_runner import TrialRunner
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import merge_parameters

_LOG = logging.getLogger(__name__)


class Scheduler(ContextManager, metaclass=ABCMeta):
    # pylint: disable=too-many-instance-attributes
    """Base class for the optimization loop scheduling policies."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        config: dict[str, Any],
        global_config: dict[str, Any],
        trial_runners: Iterable[TrialRunner],
        optimizer: Optimizer,
        storage: Storage,
        root_env_config: str,
    ):
        """
        Create a new instance of the scheduler. The constructor of this and the derived
        classes is called by the persistence service after reading the class JSON
        configuration. Other objects like the TrialRunner(s) and their Environment(s)
        and Optimizer are provided by the Launcher.

        Parameters
        ----------
        config : dict
            The configuration for the Scheduler.
        global_config : dict
            The global configuration for the Experiment.
        trial_runners : Iterable[TrialRunner]
            The set of TrialRunner(s) (and associated Environment(s)) to benchmark/optimize.
        optimizer : Optimizer
            The Optimizer to use.
        storage : Storage
            The Storage to use.
        root_env_config : str
            Path to the root Environment configuration.
        """
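        # Illustrative only: a typical `config`, after merging with the global config,
        # might look like (hypothetical values):
        #   {"experiment_id": "MyExperiment", "trial_id": 1, "max_trials": 100,
        #    "trial_config_repeat_count": 3, "teardown": True}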

        self.global_config = global_config
        config = merge_parameters(
            dest=config.copy(),
            source=global_config,
            required_keys=["experiment_id", "trial_id"],
        )
        self._validate_json_config(config)

        self._in_context = False
        self._experiment_id = config["experiment_id"].strip()
        self._trial_id = int(config["trial_id"])
        self._config_id = int(config.get("config_id", -1))
        self._max_trials = int(config.get("max_trials", -1))
        self._trial_count = 0

        self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
        if self._trial_config_repeat_count <= 0:
            raise ValueError(
                f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}"
            )

        self._do_teardown = bool(config.get("teardown", True))

        self._experiment: Storage.Experiment | None = None

        assert trial_runners, "At least one TrialRunner is required"
        trial_runners = list(trial_runners)
        self._trial_runners = {
            trial_runner.trial_runner_id: trial_runner for trial_runner in trial_runners
        }
        self._current_trial_runner_idx = 0
        self._trial_runner_ids = list(self._trial_runners.keys())
        assert len(self._trial_runner_ids) == len(
            trial_runners
        ), f"Duplicate TrialRunner ids detected: {trial_runners}"

        self._optimizer = optimizer
        self._storage = storage
        self._root_env_config = root_env_config
        self._last_trial_id = -1
        self._ran_trials: list[Storage.Trial] = []

        _LOG.debug("Scheduler instantiated: %s :: %s", self, config)

    def _validate_json_config(self, config: dict) -> None:
        """Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading
        mechanism.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config.copy()
            # The json schema does not allow for -1 as a valid value for config_id.
            # As it is just a default placeholder value, and not required, we can
            # remove it from the config copy prior to validation safely.
            config_id = json_config["config"].get("config_id")
            if config_id is not None and isinstance(config_id, int) and config_id < 0:
                json_config["config"].pop("config_id")
        ConfigSchema.SCHEDULER.validate(json_config)
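        # For reference, the document validated above has this overall shape
        # (hypothetical values):
        #   {
        #       "class": "mlos_bench.schedulers.sync_scheduler.SyncScheduler",
        #       "config": {"experiment_id": "MyExperiment", "trial_id": 1},
        #   }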

    @property
    def trial_config_repeat_count(self) -> int:
        """Gets the number of trials to run for a given config."""
        return self._trial_config_repeat_count

    @property
    def trial_count(self) -> int:
        """Gets the current number of trials run for the experiment."""
        return self._trial_count

    @property
    def max_trials(self) -> int:
        """Gets the maximum number of trials to run for a given experiment, or -1 for no
        limit.
        """
        return self._max_trials

    @property
    def experiment(self) -> Storage.Experiment | None:
        """Gets the Experiment Storage."""
        return self._experiment

    @property
    def _root_trial_runner_id(self) -> int:
        # Use the first TrialRunner as the root.
        return self._trial_runner_ids[0]

    @property
    def root_environment(self) -> Environment:
        """
        Gets the root (prototypical) Environment from the first TrialRunner.

        Notes
        -----
        All TrialRunners have the same Environment config and are made
        unique by their use of the unique trial_runner_id assigned to each
        TrialRunner's Environment's global_config.
        """
        # Use the first TrialRunner's Environment as the root Environment.
        return self._trial_runners[self._root_trial_runner_id].environment

    @property
    def trial_runners(self) -> dict[int, TrialRunner]:
        """Gets the set of Trial Runners."""
        return self._trial_runners

    @property
    def environments(self) -> Iterable[Environment]:
        """Gets the Environments from the TrialRunners."""
        return (trial_runner.environment for trial_runner in self._trial_runners.values())

    @property
    def optimizer(self) -> Optimizer:
        """Gets the Optimizer."""
        return self._optimizer

    @property
    def storage(self) -> Storage:
        """Gets the Storage."""
        return self._storage

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the Scheduler (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the Scheduler.
        """
        return self.__class__.__name__

    def __enter__(self) -> "Scheduler":
        """Enter the scheduler's context."""
        _LOG.debug("Scheduler START :: %s", self)
        assert self.experiment is None
        assert not self._in_context
        for trial_runner in self._trial_runners.values():
            trial_runner.__enter__()
        self._optimizer.__enter__()
        # Start new or resume the existing experiment. Verify that the
        # experiment configuration is compatible with the previous runs.
        # If the `merge` config parameter is present, merge in the data
        # from other experiments and check for compatibility.
        self._experiment = self.storage.experiment(
            experiment_id=self._experiment_id,
            trial_id=self._trial_id,
            root_env_config=self._root_env_config,
            description=self.root_environment.name,
            tunables=self.root_environment.tunable_params,
            opt_targets=self.optimizer.targets,
        ).__enter__()
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """Exit the context of the scheduler."""
        if ex_val is None:
            _LOG.debug("Scheduler END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Scheduler END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self._in_context
        assert self._experiment is not None
        self._experiment.__exit__(ex_type, ex_val, ex_tb)
        self._optimizer.__exit__(ex_type, ex_val, ex_tb)
        for trial_runner in self._trial_runners.values():
            trial_runner.__exit__(ex_type, ex_val, ex_tb)
        self._experiment = None
        self._in_context = False
        return False  # Do not suppress exceptions

    @abstractmethod
    def start(self) -> None:
        """Start the scheduling loop."""
        assert self.experiment is not None
        _LOG.info(
            "START: Experiment: %s Env: %s Optimizer: %s",
            self._experiment,
            self.root_environment,
            self.optimizer,
        )
        if _LOG.isEnabledFor(logging.INFO):
            _LOG.info("Root Environment:\n%s", self.root_environment.pprint())

        if self._config_id > 0:
            tunables = self.load_tunable_config(self._config_id)
            self.schedule_trial(tunables)

    def teardown(self) -> None:
        """
        Tear down the TrialRunners/Environment(s).

        Call it after the completion of `.start()` within the scheduler's context.
        """
        assert self.experiment is not None
        if self._do_teardown:
            for trial_runner in self._trial_runners.values():
                assert not trial_runner.is_running
                trial_runner.teardown()

    def get_best_observation(self) -> tuple[dict[str, float] | None, TunableGroups | None]:
        """Get the best observation from the optimizer."""
        (best_score, best_config) = self.optimizer.get_best_observation()
        _LOG.info("Env: %s best score: %s", self.root_environment, best_score)
        return (best_score, best_config)

    def load_tunable_config(self, config_id: int) -> TunableGroups:
        """Load the existing tunable configuration from the storage."""
        assert self.experiment is not None
        tunable_values = self.experiment.load_tunable_config(config_id)
        tunables = TunableGroups()
        for environment in self.environments:
            tunables = environment.tunable_params.assign(tunable_values)
        _LOG.info("Load config from storage: %d", config_id)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2))
        return tunables.copy()

    def _schedule_new_optimizer_suggestions(self) -> bool:
        """
        Optimizer part of the loop.

        Load the results of the executed trials into the optimizer, suggest new
        configurations, and add them to the queue. Return True if optimization is not
        over, False otherwise.
        """
        assert self.experiment is not None
        (trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
        _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids)
        self.optimizer.bulk_register(configs, scores, status)
        self._last_trial_id = max(trial_ids, default=self._last_trial_id)

        not_done = self.not_done()
        if not_done:
            tunables = self.optimizer.suggest()
            self.schedule_trial(tunables)

        return not_done
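        # Note: concrete subclasses typically alternate this method with
        # `_run_schedule()` inside `start()` until it returns False.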

    def schedule_trial(self, tunables: TunableGroups) -> None:
        """Add a configuration to the queue of trials."""
        # TODO: Alternative scheduling policies may prefer to expand repeats over
        # time as well as space, or adjust the number of repeats (budget) of a given
        # trial based on whether initial results are promising.
        for repeat_i in range(1, self._trial_config_repeat_count + 1):
            self._add_trial_to_queue(
                tunables,
                config={
                    # Add some additional metadata to track for the trial such as the
                    # optimizer config used.
                    # Note: these values are unfortunately mutable at the moment.
                    # Consider them as hints of what the config was when the trial *started*.
                    # It is possible that the experiment configs were changed
                    # between resumes of the experiment (since that is not
                    # currently prevented).
                    "optimizer": self.optimizer.name,
                    "repeat_i": repeat_i,
                    "is_defaults": tunables.is_defaults(),
                    **{
                        f"opt_{key}_{i}": val
                        for (i, opt_target) in enumerate(self.optimizer.targets.items())
                        for (key, val) in zip(["target", "direction"], opt_target)
                    },
                },
            )
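            # Note: the `opt_target_{i}` / `opt_direction_{i}` entries above record each
            # optimization target and its direction, e.g. (hypothetical target name)
            # "opt_target_0": "score", "opt_direction_0": "min".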

    def _add_trial_to_queue(
        self,
        tunables: TunableGroups,
        ts_start: datetime | None = None,
        config: dict[str, Any] | None = None,
    ) -> None:
        """
        Add a configuration to the queue of trials in the Storage backend.

        A wrapper for the `Experiment.new_trial` method.
        """
        assert self.experiment is not None
        trial = self.experiment.new_trial(tunables, ts_start, config)
        _LOG.info("QUEUE: Added new trial: %s", trial)

    def assign_trial_runners(self, trials: Iterable[Storage.Trial]) -> None:
        """
        Assigns TrialRunners to the given Trials in batch.

        The base class implements a simple round-robin scheduling algorithm for each
        Trial in sequence.

        Subclasses can override this method to implement a more sophisticated policy.
        For instance::

            def assign_trial_runners(
                self,
                trials: Iterable[Storage.Trial],
            ) -> None:
                trial_runners_map = {}
                # Implement a more sophisticated policy here.
                # For example, to assign the Trial to the TrialRunner with the least
                # number of running Trials.
                # Or assign the Trial to the TrialRunner that hasn't executed this
                # TunableValues Config yet.
                for (trial, trial_runner) in trial_runners_map.items():
                    # Call the base class method to assign the TrialRunner in the Trial's metadata.
                    trial.set_trial_runner(trial_runner)
                ...

        Parameters
        ----------
        trials : Iterable[Storage.Trial]
            The trials to assign TrialRunners to.
        """
        for trial in trials:
            if trial.trial_runner_id is not None:
                _LOG.info(
                    "Trial %s already has a TrialRunner assigned: %s",
                    trial,
                    trial.trial_runner_id,
                )
                continue

            # Basic round-robin trial runner assignment policy:
            # fetch and increment the current TrialRunner index.
            # Override in the subclass for a more sophisticated policy.
            trial_runner_idx = self._current_trial_runner_idx
            self._current_trial_runner_idx += 1
            self._current_trial_runner_idx %= len(self._trial_runner_ids)
            trial_runner = self._trial_runners[self._trial_runner_ids[trial_runner_idx]]
            assert trial_runner
            _LOG.info(
                "Assigning TrialRunner %s to Trial %s via basic round-robin policy.",
                trial_runner,
                trial,
            )
            assigned_trial_runner_id = trial.set_trial_runner(trial_runner.trial_runner_id)
            if assigned_trial_runner_id != trial_runner.trial_runner_id:
                raise ValueError(
                    f"Failed to assign TrialRunner {trial_runner} to Trial {trial}: "
                    f"{assigned_trial_runner_id}"
                )

    def get_trial_runner(self, trial: Storage.Trial) -> TrialRunner:
        """
        Gets the TrialRunner associated with the given Trial.

        Parameters
        ----------
        trial : Storage.Trial
            The trial to get the associated TrialRunner for.

        Returns
        -------
        TrialRunner
        """
        if trial.trial_runner_id is None:
            self.assign_trial_runners([trial])
            assert trial.trial_runner_id is not None
        trial_runner = self._trial_runners.get(trial.trial_runner_id)
        if trial_runner is None:
            raise ValueError(
                f"TrialRunner {trial.trial_runner_id} for Trial {trial} "
                f"not found: {self._trial_runners}"
            )
        assert trial_runner.trial_runner_id == trial.trial_runner_id
        return trial_runner

    def _run_schedule(self, running: bool = False) -> None:
        """
        Scheduler part of the loop.

        Check for pending trials in the queue and run them.
        """
        assert self.experiment is not None
        # Make sure that any pending trials have a TrialRunner assigned.
        pending_trials = list(self.experiment.pending_trials(datetime.now(UTC), running=running))
        self.assign_trial_runners(pending_trials)
        for trial in pending_trials:
            self.run_trial(trial)

    def not_done(self) -> bool:
        """
        Check the stopping conditions.

        By default, stop when the optimizer converges or the maximum number of trials
        is reached (a non-positive ``max_trials`` means no limit).
        """
        return self.optimizer.not_converged() and (
            self._trial_count < self._max_trials or self._max_trials <= 0
        )

    @abstractmethod
    def run_trial(self, trial: Storage.Trial) -> None:
        """
        Set up and run a single trial.

        Save the results in the storage.
        """
        assert self._in_context
        assert self.experiment is not None
        self._trial_count += 1
        self._ran_trials.append(trial)
        _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial)

    @property
    def ran_trials(self) -> list[Storage.Trial]:
        """Get the list of trials that were run."""
        return self._ran_trials
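Usage sketch (illustrative, not part of the module above): how a concrete Scheduler
subclass is typically driven once the Launcher has constructed it from JSON configs.
The `scheduler` object and its construction are assumed here; only methods defined in
this base class are called.

    with scheduler:  # __enter__: open the TrialRunners, Optimizer, and Storage Experiment
        scheduler.start()  # run the scheduling loop (abstract; implemented by the subclass)
        scheduler.teardown()  # tear down the Environments (honors the `teardown` config flag)
        (best_score, best_config) = scheduler.get_best_observation()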