Coverage for mlos_bench/mlos_bench/storage/base_storage.py: 96%

166 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Base interface for saving and restoring the benchmark data. 

7 

8See Also 

9-------- 

10mlos_bench.storage.base_storage.Storage.experiments : 

11 Retrieves a dictionary of the Experiments' data. 

12mlos_bench.storage.base_experiment_data.ExperimentData.results_df : 

13 Retrieves a pandas DataFrame of the Experiment's trials' results data. 

14mlos_bench.storage.base_experiment_data.ExperimentData.trials : 

15 Retrieves a dictionary of the Experiment's trials' data. 

16mlos_bench.storage.base_experiment_data.ExperimentData.tunable_configs : 

17 Retrieves a dictionary of the Experiment's sampled configs data. 

18mlos_bench.storage.base_experiment_data.ExperimentData.tunable_config_trial_groups : 

19 Retrieves a dictionary of the Experiment's trials' data, grouped by shared 

20 tunable config. 

21mlos_bench.storage.base_trial_data.TrialData : 

22 Base interface for accessing the stored benchmark trial data. 

23""" 

24 

25import logging 

26from abc import ABCMeta, abstractmethod 

27from collections.abc import Iterator, Mapping 

28from contextlib import AbstractContextManager as ContextManager 

29from datetime import datetime 

30from types import TracebackType 

31from typing import Any, Literal 

32 

33from mlos_bench.config.schemas import ConfigSchema 

34from mlos_bench.dict_templater import DictTemplater 

35from mlos_bench.environments.status import Status 

36from mlos_bench.services.base_service import Service 

37from mlos_bench.storage.base_experiment_data import ExperimentData 

38from mlos_bench.tunables.tunable_groups import TunableGroups 

39from mlos_bench.util import get_git_info 

40 

41_LOG = logging.getLogger(__name__) 

42 

43 

44class Storage(metaclass=ABCMeta): 

45 """An abstract interface between the benchmarking framework and storage systems 

46 (e.g., SQLite or MLFLow). 

47 """ 

48 

    def __init__(
        self,
        config: dict[str, Any],
        global_config: dict | None = None,
        service: Service | None = None,
    ):
        """
        Create a new storage object.

        Parameters
        ----------
        config : dict
            Free-format key/value pairs of configuration parameters.
        global_config : dict | None
            Global configuration parameters; stored as-is (an empty dict when
            omitted).
        service : Service | None
            Optional service object; stored for use by concrete backends.
        """
        _LOG.debug("Storage config: %s", config)
        # Fail early on malformed configs rather than at first use.
        self._validate_json_config(config)
        self._service = service
        # Copy so later mutations by the caller don't leak into the storage object.
        self._config = config.copy()
        self._global_config = global_config or {}

68 

    @abstractmethod
    def update_schema(self) -> None:
        """
        Update the schema of the storage backend if needed.

        Abstract: concrete backends implement whatever migration logic applies.
        """

72 

73 def _validate_json_config(self, config: dict) -> None: 

74 """Reconstructs a basic json config that this class might have been instantiated 

75 from in order to validate configs provided outside the file loading 

76 mechanism. 

77 """ 

78 json_config: dict = { 

79 "class": self.__class__.__module__ + "." + self.__class__.__name__, 

80 } 

81 if config: 

82 json_config["config"] = config 

83 ConfigSchema.STORAGE.validate(json_config) 

84 

    @property
    @abstractmethod
    def experiments(self) -> dict[str, ExperimentData]:
        """
        Retrieve the experiments' data from the storage.

        Abstract: implemented by concrete storage backends.

        Returns
        -------
        experiments : dict[str, ExperimentData]
            A dictionary of the experiments' data, keyed by experiment id.
        """

96 

    @abstractmethod
    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: dict[str, Literal["min", "max"]],
    ) -> "Storage.Experiment":
        """
        Create a new experiment in the storage.

        We need the `opt_target` parameter here to know what metric to retrieve
        when we load the data from previous trials. Later we will replace it with
        full metadata about the optimization direction, multiple objectives, etc.

        Parameters
        ----------
        experiment_id : str
            Unique identifier of the experiment.
        trial_id : int
            Starting number of the trial.
        root_env_config : str
            A path to the root JSON configuration file of the benchmarking environment.
        description : str
            Human-readable description of the experiment.
        tunables : TunableGroups
            Tunable parameters to be used in the experiment's trials.
        opt_targets : dict[str, Literal["min", "max"]]
            Names of metrics we're optimizing for and the optimization direction {min, max}.

        Returns
        -------
        experiment : Storage.Experiment
            An object that allows to update the storage with
            the results of the experiment and related data.
        """

135 

136 class Experiment(ContextManager, metaclass=ABCMeta): 

137 # pylint: disable=too-many-instance-attributes 

138 """ 

139 Base interface for storing the results of the experiment. 

140 

141 This class is instantiated in the `Storage.experiment()` method. 

142 """ 

143 

144 def __init__( # pylint: disable=too-many-arguments 

145 self, 

146 *, 

147 tunables: TunableGroups, 

148 experiment_id: str, 

149 trial_id: int, 

150 root_env_config: str, 

151 description: str, 

152 opt_targets: dict[str, Literal["min", "max"]], 

153 ): 

154 self._tunables = tunables.copy() 

155 self._trial_id = trial_id 

156 self._experiment_id = experiment_id 

157 (self._git_repo, self._git_commit, self._root_env_config) = get_git_info( 

158 root_env_config 

159 ) 

160 self._description = description 

161 self._opt_targets = opt_targets 

162 self._in_context = False 

163 

        def __enter__(self) -> "Storage.Experiment":
            """
            Enter the context of the experiment.

            Override the `_setup` method to add custom context initialization.
            """
            _LOG.debug("Starting experiment: %s", self)
            # Nested / re-entrant use of the same Experiment is not supported.
            assert not self._in_context
            self._setup()
            self._in_context = True
            return self

175 

176 def __exit__( 

177 self, 

178 exc_type: type[BaseException] | None, 

179 exc_val: BaseException | None, 

180 exc_tb: TracebackType | None, 

181 ) -> Literal[False]: 

182 """ 

183 End the context of the experiment. 

184 

185 Override the `_teardown` method to add custom context teardown logic. 

186 """ 

187 is_ok = exc_val is None 

188 if is_ok: 

189 _LOG.debug("Finishing experiment: %s", self) 

190 else: 

191 assert exc_type and exc_val 

192 _LOG.warning( 

193 "Finishing experiment: %s", 

194 self, 

195 exc_info=(exc_type, exc_val, exc_tb), 

196 ) 

197 assert self._in_context 

198 self._teardown(is_ok) 

199 self._in_context = False 

200 return False # Do not suppress exceptions 

201 

        def __repr__(self) -> str:
            # The experiment ID serves as the string representation.
            return self._experiment_id

204 

        def _setup(self) -> None:
            """
            Create a record of the new experiment or find an existing one in the
            storage.

            This method is called by `Storage.Experiment.__enter__()`.
            The base implementation is a no-op; subclasses override it.
            """

        def _teardown(self, is_ok: bool) -> None:
            """
            Finalize the experiment in the storage.

            This method is called by `Storage.Experiment.__exit__()`.
            The base implementation is a no-op; subclasses override it.

            Parameters
            ----------
            is_ok : bool
                True if there were no exceptions during the experiment, False otherwise.
            """

224 

        @property
        def experiment_id(self) -> str:
            """Get the Experiment's ID (unique identifier of the experiment)."""
            return self._experiment_id

        @property
        def trial_id(self) -> int:
            """Get the current Trial ID (the starting trial number at creation)."""
            return self._trial_id

        @property
        def description(self) -> str:
            """Get the Experiment's human-readable description."""
            return self._description

        @property
        def root_env_config(self) -> str:
            """Get the Experiment's root Environment config file path."""
            return self._root_env_config

        @property
        def tunables(self) -> TunableGroups:
            """Get the Experiment's tunables (a private copy made at init)."""
            return self._tunables

        @property
        def opt_targets(self) -> dict[str, Literal["min", "max"]]:
            """Get the Experiment's optimization targets and directions."""
            return self._opt_targets

254 

        @abstractmethod
        def merge(self, experiment_ids: list[str]) -> None:
            """
            Merge in the results of other (compatible) experiments trials. Used to help
            warm up the optimizer for this experiment.

            Abstract: implemented by concrete storage backends.

            Parameters
            ----------
            experiment_ids : list[str]
                List of IDs of the experiments to merge in.
            """

266 

        @abstractmethod
        def load_tunable_config(self, config_id: int) -> dict[str, Any]:
            """
            Load tunable values for a given config ID.

            Parameters
            ----------
            config_id : int
                ID of the (tunable) configuration to look up.

            Returns
            -------
            config : dict[str, Any]
                Key/value pairs of the tunable parameter values.
            """

270 

        @abstractmethod
        def load_telemetry(self, trial_id: int) -> list[tuple[datetime, str, Any]]:
            """
            Retrieve the telemetry data for a given trial.

            Abstract: implemented by concrete storage backends.

            Parameters
            ----------
            trial_id : int
                Trial ID.

            Returns
            -------
            metrics : list[tuple[datetime.datetime, str, Any]]
                Telemetry data as (timestamp, metric name, value) tuples.
            """

286 

        @abstractmethod
        def load(
            self,
            last_trial_id: int = -1,
        ) -> tuple[list[int], list[dict], list[dict[str, Any] | None], list[Status]]:
            """
            Load (tunable values, benchmark scores, status) to warm-up the optimizer.

            If `last_trial_id` is present, load only the data from the (completed) trials
            that were scheduled *after* the given trial ID. Otherwise, return data from ALL
            merged-in experiments and attempt to impute the missing tunable values.

            Parameters
            ----------
            last_trial_id : int
                (Optional) Trial ID to start from.

            Returns
            -------
            (trial_ids, configs, scores, status) : ([int], [dict], [dict] | None, [Status])
                Trial ids, Tunable values, benchmark scores, and status of the trials.
                (Presumably index-aligned across the four lists -- confirm in the
                concrete implementations.)
            """

309 

        @abstractmethod
        def pending_trials(
            self,
            timestamp: datetime,
            *,
            running: bool,
        ) -> Iterator["Storage.Trial"]:
            """
            Return an iterator over the pending trials that are scheduled to run on or
            before the specified timestamp.

            Abstract: implemented by concrete storage backends.

            Parameters
            ----------
            timestamp : datetime.datetime
                The time in UTC to check for scheduled trials.
            running : bool
                If True, include the trials that are already running.
                Otherwise, return only the scheduled trials.

            Returns
            -------
            trials : Iterator[Storage.Trial]
                An iterator over the scheduled (and maybe running) trials.
            """

334 

335 def new_trial( 

336 self, 

337 tunables: TunableGroups, 

338 ts_start: datetime | None = None, 

339 config: dict[str, Any] | None = None, 

340 ) -> "Storage.Trial": 

341 """ 

342 Create a new experiment run in the storage. 

343 

344 Parameters 

345 ---------- 

346 tunables : TunableGroups 

347 Tunable parameters to use for the trial. 

348 ts_start : datetime.datetime | None 

349 Timestamp of the trial start (can be in the future). 

350 config : dict 

351 Key/value pairs of additional non-tunable parameters of the trial. 

352 

353 Returns 

354 ------- 

355 trial : Storage.Trial 

356 An object that allows to update the storage with 

357 the results of the experiment trial run. 

358 """ 

359 # Check that `config` is json serializable (e.g., no callables) 

360 if config: 

361 try: 

362 # Relies on the fact that DictTemplater only accepts primitive 

363 # types in it's nested dict structure walk. 

364 _config = DictTemplater(config).expand_vars() 

365 assert isinstance(_config, dict) 

366 except ValueError as e: 

367 _LOG.error("Non-serializable config: %s", config, exc_info=e) 

368 raise e 

369 return self._new_trial(tunables, ts_start, config) 

370 

        @abstractmethod
        def _new_trial(
            self,
            tunables: TunableGroups,
            ts_start: datetime | None = None,
            config: dict[str, Any] | None = None,
        ) -> "Storage.Trial":
            """
            Create a new experiment run in the storage.

            Called by `new_trial()` after the config serializability check;
            implemented by concrete storage backends.

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters to use for the trial.
            ts_start : datetime.datetime | None
                Timestamp of the trial start (can be in the future).
            config : dict
                Key/value pairs of additional non-tunable parameters of the trial.

            Returns
            -------
            trial : Storage.Trial
                An object that allows to update the storage with
                the results of the experiment trial run.
            """

396 

397 class Trial(metaclass=ABCMeta): 

398 # pylint: disable=too-many-instance-attributes 

399 """ 

400 Base interface for storing the results of a single run of the experiment. 

401 

402 This class is instantiated in the `Storage.Experiment.trial()` method. 

403 """ 

404 

405 def __init__( # pylint: disable=too-many-arguments 

406 self, 

407 *, 

408 tunables: TunableGroups, 

409 experiment_id: str, 

410 trial_id: int, 

411 tunable_config_id: int, 

412 trial_runner_id: int | None = None, 

413 opt_targets: dict[str, Literal["min", "max"]], 

414 config: dict[str, Any] | None = None, 

415 status: Status = Status.UNKNOWN, 

416 ): 

417 if status not in (Status.UNKNOWN, Status.PENDING): 

418 raise ValueError(f"Invalid status for a new trial: {status}") 

419 self._tunables = tunables 

420 self._experiment_id = experiment_id 

421 self._trial_id = trial_id 

422 self._tunable_config_id = tunable_config_id 

423 self._trial_runner_id = trial_runner_id 

424 self._opt_targets = opt_targets 

425 self._config = config or {} 

426 self._status = status 

427 

428 def __repr__(self) -> str: 

429 return ( 

430 f"{self._experiment_id}:{self._trial_id}:" 

431 f"{self._tunable_config_id}:{self.trial_runner_id}" 

432 ) 

433 

        @property
        def trial_id(self) -> int:
            """ID of the current trial."""
            return self._trial_id

        @property
        def tunable_config_id(self) -> int:
            """ID of the current trial (tunable) configuration."""
            return self._tunable_config_id

        @property
        def trial_runner_id(self) -> int | None:
            """ID of the TrialRunner this trial is assigned to (None if unassigned)."""
            return self._trial_runner_id

448 

449 def opt_targets(self) -> dict[str, Literal["min", "max"]]: 

450 """Get the Trial's optimization targets and directions.""" 

451 return self._opt_targets 

452 

        @property
        def tunables(self) -> TunableGroups:
            """
            Tunable parameters of the current trial.

            (e.g., application Environment's "config")

            Note: unlike `Storage.Experiment`, the Trial stores the tunables
            object as-is (no copy) -- presumably intentional; confirm callers
            don't mutate it after creating the trial.
            """
            return self._tunables

461 

462 @abstractmethod 

463 def set_trial_runner(self, trial_runner_id: int) -> int: 

464 """Assign the trial to a specific TrialRunner.""" 

465 if self._trial_runner_id is None or self._status.is_pending(): 

466 _LOG.debug( 

467 "%sSetting Trial %s to TrialRunner %d", 

468 "Re-" if self._trial_runner_id else "", 

469 self, 

470 trial_runner_id, 

471 ) 

472 self._trial_runner_id = trial_runner_id 

473 else: 

474 _LOG.warning( 

475 "Trial %s already assigned to a TrialRunner, cannot switch to %d", 

476 self, 

477 self._trial_runner_id, 

478 ) 

479 return self._trial_runner_id 

480 

481 def config(self, global_config: dict[str, Any] | None = None) -> dict[str, Any]: 

482 """ 

483 Produce a copy of the global configuration updated with the parameters of 

484 the current trial. 

485 

486 Note: this is not the target Environment's "config" (i.e., tunable 

487 params), but rather the internal "config" which consists of a 

488 combination of somewhat more static variables defined in the json config 

489 files. 

490 """ 

491 config = self._config.copy() 

492 config.update(global_config or {}) 

493 # Here we add some built-in variables for the trial to use while it's running. 

494 config["experiment_id"] = self._experiment_id 

495 config["trial_id"] = self._trial_id 

496 trial_runner_id = self.trial_runner_id 

497 if trial_runner_id is not None: 

498 config["trial_runner_id"] = trial_runner_id 

499 return config 

500 

501 def add_new_config_data( 

502 self, 

503 new_config_data: Mapping[str, int | float | str], 

504 ) -> None: 

505 """ 

506 Add new config data to the trial. 

507 

508 Parameters 

509 ---------- 

510 new_config_data : dict[str, int | float | str] 

511 New data to add (must not already exist for the trial). 

512 

513 Raises 

514 ------ 

515 ValueError 

516 If any of the data already exists. 

517 """ 

518 for key, value in new_config_data.items(): 

519 if key in self._config: 

520 raise ValueError( 

521 f"New config data {key}={value} already exists for trial {self}: " 

522 f"{self._config[key]}" 

523 ) 

524 self._config[key] = value 

525 self._save_new_config_data(new_config_data) 

526 

        @abstractmethod
        def _save_new_config_data(
            self,
            new_config_data: Mapping[str, int | float | str],
        ) -> None:
            """
            Save the new config data to the storage.

            Abstract: implemented by concrete storage backends.

            Parameters
            ----------
            new_config_data : dict[str, int | float | str]
                New data to add.
            """

540 

        @property
        def status(self) -> Status:
            """Get the status of the current trial (updated via `update()`)."""
            return self._status

545 

546 @abstractmethod 

547 def update( 

548 self, 

549 status: Status, 

550 timestamp: datetime, 

551 metrics: dict[str, Any] | None = None, 

552 ) -> dict[str, Any] | None: 

553 """ 

554 Update the storage with the results of the experiment. 

555 

556 Parameters 

557 ---------- 

558 status : Status 

559 Status of the experiment run. 

560 timestamp: datetime.datetime 

561 Timestamp of the status and metrics. 

562 metrics : Optional[dict[str, Any]] 

563 One or several metrics of the experiment run. 

564 Must contain the (float) optimization target if the status is SUCCEEDED. 

565 

566 Returns 

567 ------- 

568 metrics : Optional[dict[str, Any]] 

569 Same as `metrics`, but always in the dict format. 

570 """ 

571 _LOG.info("Store trial: %s :: %s %s", self, status, metrics) 

572 if status.is_succeeded(): 

573 assert metrics is not None 

574 opt_targets = set(self._opt_targets.keys()) 

575 if not opt_targets.issubset(metrics.keys()): 

576 _LOG.warning( 

577 "Trial %s :: opt.targets missing: %s", 

578 self, 

579 opt_targets.difference(metrics.keys()), 

580 ) 

581 # raise ValueError() 

582 self._status = status 

583 return metrics 

584 

        @abstractmethod
        def update_telemetry(
            self,
            status: Status,
            timestamp: datetime,
            metrics: list[tuple[datetime, str, Any]],
        ) -> None:
            """
            Save the experiment's telemetry data and intermediate status.

            The base implementation only logs the call; subclasses persist the
            telemetry.

            Parameters
            ----------
            status : Status
                Current status of the trial.
            timestamp: datetime.datetime
                Timestamp of the status (but not the metrics).
            metrics : list[tuple[datetime.datetime, str, Any]]
                Telemetry data as (timestamp, metric name, value) tuples.
            """
            _LOG.info("Store telemetry: %s :: %s %d records", self, status, len(metrics))