Coverage for mlos_bench/mlos_bench/storage/base_storage.py: 97%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Base interface for saving and restoring the benchmark data. 

7 

8See Also 

9-------- 

10mlos_bench.storage.base_storage.Storage.experiments : 

11 Retrieves a dictionary of the Experiments' data. 

12mlos_bench.storage.base_experiment_data.ExperimentData.results_df : 

13 Retrieves a pandas DataFrame of the Experiment's trials' results data. 

14mlos_bench.storage.base_experiment_data.ExperimentData.trials : 

15 Retrieves a dictionary of the Experiment's trials' data. 

16mlos_bench.storage.base_experiment_data.ExperimentData.tunable_configs : 

17 Retrieves a dictionary of the Experiment's sampled configs data. 

18mlos_bench.storage.base_experiment_data.ExperimentData.tunable_config_trial_groups : 

19 Retrieves a dictionary of the Experiment's trials' data, grouped by shared 

20 tunable config. 

21mlos_bench.storage.base_trial_data.TrialData : 

22 Base interface for accessing the stored benchmark trial data. 

23""" 

24 

25import logging 

26from abc import ABCMeta, abstractmethod 

27from datetime import datetime 

28from types import TracebackType 

29from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Type 

30 

31from mlos_bench.config.schemas import ConfigSchema 

32from mlos_bench.dict_templater import DictTemplater 

33from mlos_bench.environments.status import Status 

34from mlos_bench.services.base_service import Service 

35from mlos_bench.storage.base_experiment_data import ExperimentData 

36from mlos_bench.tunables.tunable_groups import TunableGroups 

37from mlos_bench.util import get_git_info 

38 

# Module-level logger, named after this module.
_LOG = logging.getLogger(name=__name__)

40 

41 

class Storage(metaclass=ABCMeta):
    """An abstract interface between the benchmarking framework and storage systems
    (e.g., SQLite or MLflow).
    """

    def __init__(
        self,
        config: Dict[str, Any],
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new storage object.

        Parameters
        ----------
        config : dict
            Free-format key/value pairs of configuration parameters.
            Validated against the storage JSON schema, then shallow-copied.
        global_config : Optional[dict]
            Free-format dictionary of global parameters. Defaults to an empty dict.
        service : Optional[Service]
            Optional service object; stored as ``self._service`` and not used by
            this base class itself.
        """
        _LOG.debug("Storage config: %s", config)
        self._validate_json_config(config)
        self._service = service
        self._config = config.copy()
        self._global_config = global_config or {}

    def _validate_json_config(self, config: dict) -> None:
        """Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading
        mechanism.

        The reconstructed document has the form
        ``{"class": "<module>.<ClassName>", "config": config}`` (the "config" key
        is omitted when `config` is empty) and is checked with
        ``ConfigSchema.STORAGE.validate``, which raises on schema violations.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config
        ConfigSchema.STORAGE.validate(json_config)

    @property
    @abstractmethod
    def experiments(self) -> Dict[str, ExperimentData]:
        """
        Retrieve the experiments' data from the storage.

        Returns
        -------
        experiments : Dict[str, ExperimentData]
            A dictionary of the experiments' data, keyed by experiment id.
        """

    @abstractmethod
    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: Dict[str, Literal["min", "max"]],
    ) -> "Storage.Experiment":
        """
        Create a new experiment in the storage.

        We need the `opt_targets` parameter here to know what metrics to retrieve
        when we load the data from previous trials. Later we will replace it with
        full metadata about the optimization direction, multiple objectives, etc.

        Parameters
        ----------
        experiment_id : str
            Unique identifier of the experiment.
        trial_id : int
            Starting number of the trial.
        root_env_config : str
            A path to the root JSON configuration file of the benchmarking environment.
        description : str
            Human-readable description of the experiment.
        tunables : TunableGroups
            Tunable parameters of the experiment.
        opt_targets : Dict[str, Literal["min", "max"]]
            Names of metrics we're optimizing for and the optimization direction {min, max}.

        Returns
        -------
        experiment : Storage.Experiment
            An object that allows to update the storage with
            the results of the experiment and related data.
        """

    class Experiment(metaclass=ABCMeta):
        # pylint: disable=too-many-instance-attributes
        """
        Base interface for storing the results of the experiment.

        This class is instantiated in the `Storage.experiment()` method.
        """

        def __init__(  # pylint: disable=too-many-arguments
            self,
            *,
            tunables: TunableGroups,
            experiment_id: str,
            trial_id: int,
            root_env_config: str,
            description: str,
            opt_targets: Dict[str, Literal["min", "max"]],
        ):
            """
            Initialize the experiment's state (not yet entered as a context).

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters of the experiment; a copy is stored.
            experiment_id : str
                Unique identifier of the experiment.
            trial_id : int
                Starting number of the trial.
            root_env_config : str
                Path to the root environment config; passed through `get_git_info`,
                whose third returned value is what `root_env_config` exposes.
            description : str
                Human-readable description of the experiment.
            opt_targets : Dict[str, Literal["min", "max"]]
                Names of the optimization metrics and their directions.
            """
            self._tunables = tunables.copy()
            self._trial_id = trial_id
            self._experiment_id = experiment_id
            (self._git_repo, self._git_commit, self._root_env_config) = get_git_info(
                root_env_config
            )
            self._description = description
            self._opt_targets = opt_targets
            # Tracks whether we are inside `with experiment:` to catch misuse.
            self._in_context = False

        def __enter__(self) -> "Storage.Experiment":
            """
            Enter the context of the experiment.

            Override the `_setup` method to add custom context initialization.
            """
            _LOG.debug("Starting experiment: %s", self)
            assert not self._in_context
            self._setup()
            self._in_context = True
            return self

        def __exit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_val: Optional[BaseException],
            exc_tb: Optional[TracebackType],
        ) -> Literal[False]:
            """
            End the context of the experiment.

            Override the `_teardown` method to add custom context teardown logic.
            Always returns False, so any exception raised inside the context
            propagates to the caller.
            """
            is_ok = exc_val is None
            if is_ok:
                _LOG.debug("Finishing experiment: %s", self)
            else:
                assert exc_type and exc_val
                _LOG.warning(
                    "Finishing experiment: %s",
                    self,
                    exc_info=(exc_type, exc_val, exc_tb),
                )
            assert self._in_context
            self._teardown(is_ok)
            self._in_context = False
            return False  # Do not suppress exceptions

        def __repr__(self) -> str:
            return self._experiment_id

        def _setup(self) -> None:
            """
            Create a record of the new experiment or find an existing one in the
            storage.

            This method is called by `Storage.Experiment.__enter__()`.
            """

        def _teardown(self, is_ok: bool) -> None:
            """
            Finalize the experiment in the storage.

            This method is called by `Storage.Experiment.__exit__()`.

            Parameters
            ----------
            is_ok : bool
                True if there were no exceptions during the experiment, False otherwise.
            """

        @property
        def experiment_id(self) -> str:
            """Get the Experiment's ID."""
            return self._experiment_id

        @property
        def trial_id(self) -> int:
            """Get the current Trial ID."""
            return self._trial_id

        @property
        def description(self) -> str:
            """Get the Experiment's description."""
            return self._description

        @property
        def root_env_config(self) -> str:
            """Get the Experiment's root Environment config file path."""
            return self._root_env_config

        @property
        def tunables(self) -> TunableGroups:
            """Get the Experiment's tunables."""
            return self._tunables

        @property
        def opt_targets(self) -> Dict[str, Literal["min", "max"]]:
            """Get the Experiment's optimization targets and directions."""
            return self._opt_targets

        @abstractmethod
        def merge(self, experiment_ids: List[str]) -> None:
            """
            Merge in the results of other (compatible) experiments trials. Used to help
            warm up the optimizer for this experiment.

            Parameters
            ----------
            experiment_ids : List[str]
                List of IDs of the experiments to merge in.
            """

        @abstractmethod
        def load_tunable_config(self, config_id: int) -> Dict[str, Any]:
            """
            Load tunable values for a given config ID.

            Parameters
            ----------
            config_id : int
                ID of the (tunable) configuration to load.

            Returns
            -------
            tunable_values : Dict[str, Any]
                The tunable values of that configuration.
            """

        @abstractmethod
        def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
            """
            Retrieve the telemetry data for a given trial.

            Parameters
            ----------
            trial_id : int
                Trial ID.

            Returns
            -------
            metrics : List[Tuple[datetime.datetime, str, Any]]
                Telemetry data.
            """

        @abstractmethod
        def load(
            self,
            last_trial_id: int = -1,
        ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
            """
            Load (tunable values, benchmark scores, status) to warm-up the optimizer.

            If `last_trial_id` is present, load only the data from the (completed) trials
            that were scheduled *after* the given trial ID. Otherwise, return data from ALL
            merged-in experiments and attempt to impute the missing tunable values.

            Parameters
            ----------
            last_trial_id : int
                (Optional) Trial ID to start from. Defaults to -1.

            Returns
            -------
            (trial_ids, configs, scores, status) : ([int], [dict], [Optional[dict]], [Status])
                Trial ids, Tunable values, benchmark scores, and status of the trials.
            """

        @abstractmethod
        def pending_trials(
            self,
            timestamp: datetime,
            *,
            running: bool,
        ) -> Iterator["Storage.Trial"]:
            """
            Return an iterator over the pending trials that are scheduled to run on or
            before the specified timestamp.

            Parameters
            ----------
            timestamp : datetime.datetime
                The time in UTC to check for scheduled trials.
            running : bool
                If True, include the trials that are already running.
                Otherwise, return only the scheduled trials.

            Returns
            -------
            trials : Iterator[Storage.Trial]
                An iterator over the scheduled (and maybe running) trials.
            """

        def new_trial(
            self,
            tunables: TunableGroups,
            ts_start: Optional[datetime] = None,
            config: Optional[Dict[str, Any]] = None,
        ) -> "Storage.Trial":
            """
            Create a new experiment run in the storage.

            Validates `config` (when non-empty) before delegating to `_new_trial`.

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters to use for the trial.
            ts_start : Optional[datetime.datetime]
                Timestamp of the trial start (can be in the future).
            config : dict
                Key/value pairs of additional non-tunable parameters of the trial.

            Returns
            -------
            trial : Storage.Trial
                An object that allows to update the storage with
                the results of the experiment trial run.

            Raises
            ------
            ValueError
                If `config` contains values that DictTemplater rejects
                (e.g., non-serializable objects such as callables).
            """
            # Check that `config` is json serializable (e.g., no callables).
            if config:
                try:
                    # Relies on the fact that DictTemplater only accepts primitive
                    # types in its nested dict structure walk.
                    # The expanded result is only used for validation; the original
                    # `config` is what gets passed on to `_new_trial`.
                    _config = DictTemplater(config).expand_vars()
                    assert isinstance(_config, dict)
                except ValueError as e:
                    _LOG.error("Non-serializable config: %s", config, exc_info=e)
                    raise
            return self._new_trial(tunables, ts_start, config)

        @abstractmethod
        def _new_trial(
            self,
            tunables: TunableGroups,
            ts_start: Optional[datetime] = None,
            config: Optional[Dict[str, Any]] = None,
        ) -> "Storage.Trial":
            """
            Create a new experiment run in the storage.

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters to use for the trial.
            ts_start : Optional[datetime.datetime]
                Timestamp of the trial start (can be in the future).
            config : dict
                Key/value pairs of additional non-tunable parameters of the trial.

            Returns
            -------
            trial : Storage.Trial
                An object that allows to update the storage with
                the results of the experiment trial run.
            """

    class Trial(metaclass=ABCMeta):
        # pylint: disable=too-many-instance-attributes
        """
        Base interface for storing the results of a single run of the experiment.

        This class is instantiated in the `Storage.Experiment.new_trial()` method.
        """

        def __init__(  # pylint: disable=too-many-arguments
            self,
            *,
            tunables: TunableGroups,
            experiment_id: str,
            trial_id: int,
            tunable_config_id: int,
            opt_targets: Dict[str, Literal["min", "max"]],
            config: Optional[Dict[str, Any]] = None,
        ):
            """
            Initialize the trial's state; the initial status is `Status.UNKNOWN`.

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters of the trial (stored by reference, not copied,
                unlike in `Storage.Experiment`).
            experiment_id : str
                ID of the experiment this trial belongs to.
            trial_id : int
                ID of this trial.
            tunable_config_id : int
                ID of the trial's (tunable) configuration.
            opt_targets : Dict[str, Literal["min", "max"]]
                Names of the optimization metrics and their directions.
            config : Optional[Dict[str, Any]]
                Additional non-tunable parameters of the trial. Defaults to an empty dict.
            """
            self._tunables = tunables
            self._experiment_id = experiment_id
            self._trial_id = trial_id
            self._tunable_config_id = tunable_config_id
            self._opt_targets = opt_targets
            self._config = config or {}
            self._status = Status.UNKNOWN

        def __repr__(self) -> str:
            return f"{self._experiment_id}:{self._trial_id}:{self._tunable_config_id}"

        @property
        def trial_id(self) -> int:
            """ID of the current trial."""
            return self._trial_id

        @property
        def tunable_config_id(self) -> int:
            """ID of the current trial (tunable) configuration."""
            return self._tunable_config_id

        @property
        def opt_targets(self) -> Dict[str, Literal["min", "max"]]:
            """Get the Trial's optimization targets and directions."""
            return self._opt_targets

        @property
        def tunables(self) -> TunableGroups:
            """
            Tunable parameters of the current trial.

            (e.g., application Environment's "config")
            """
            return self._tunables

        def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
            """
            Produce a copy of the trial's configuration merged with the global
            configuration.

            The trial's own (non-tunable) config is copied first and then updated
            with `global_config`, so on key collisions the values from
            `global_config` take precedence. Finally, the "experiment_id" and
            "trial_id" keys are always set from this trial.

            Note: this is not the target Environment's "config" (i.e., tunable
            params), but rather the internal "config" which consists of a
            combination of somewhat more static variables defined in the json config
            files.

            Parameters
            ----------
            global_config : Optional[Dict[str, Any]]
                Global configuration parameters to merge in. Not modified.

            Returns
            -------
            config : Dict[str, Any]
                A new (shallow-copied) dictionary; the trial's stored config is
                not modified.
            """
            # NOTE(review): precedence is global_config over the trial's config;
            # confirm this is the intended direction.
            config = self._config.copy()
            config.update(global_config or {})
            config["experiment_id"] = self._experiment_id
            config["trial_id"] = self._trial_id
            return config

        @property
        def status(self) -> Status:
            """Get the status of the current trial."""
            return self._status

        @abstractmethod
        def update(
            self,
            status: Status,
            timestamp: datetime,
            metrics: Optional[Dict[str, Any]] = None,
        ) -> Optional[Dict[str, Any]]:
            """
            Update the storage with the results of the experiment.

            Subclasses are expected to call this base implementation, which logs
            the update, checks for missing optimization targets on success, and
            records `status` as the trial's current status.

            Parameters
            ----------
            status : Status
                Status of the experiment run.
            timestamp : datetime.datetime
                Timestamp of the status and metrics.
            metrics : Optional[Dict[str, Any]]
                One or several metrics of the experiment run.
                Must contain the (float) optimization target if the status is SUCCEEDED.

            Returns
            -------
            metrics : Optional[Dict[str, Any]]
                Same as `metrics`, but always in the dict format.
            """
            _LOG.info("Store trial: %s :: %s %s", self, status, metrics)
            if status.is_succeeded():
                assert metrics is not None
                opt_targets = set(self._opt_targets.keys())
                if not opt_targets.issubset(metrics.keys()):
                    # NOTE: Missing optimization targets only produce a warning;
                    # they do not raise.
                    _LOG.warning(
                        "Trial %s :: opt.targets missing: %s",
                        self,
                        opt_targets.difference(metrics.keys()),
                    )
            self._status = status
            return metrics

        @abstractmethod
        def update_telemetry(
            self,
            status: Status,
            timestamp: datetime,
            metrics: List[Tuple[datetime, str, Any]],
        ) -> None:
            """
            Save the experiment's telemetry data and intermediate status.

            Parameters
            ----------
            status : Status
                Current status of the trial.
            timestamp : datetime.datetime
                Timestamp of the status (but not the metrics).
            metrics : List[Tuple[datetime.datetime, str, Any]]
                Telemetry data.
            """
            _LOG.info("Store telemetry: %s :: %s %d records", self, status, len(metrics))