Coverage for mlos_bench/mlos_bench/storage/base_storage.py: 96%

137 statements  


#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Base interface for saving and restoring the benchmark data."""

import logging
from abc import ABCMeta, abstractmethod
from datetime import datetime
from types import TracebackType
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type

from typing_extensions import Literal

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.dict_templater import DictTemplater
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import get_git_info

_LOG = logging.getLogger(__name__)


class Storage(metaclass=ABCMeta):
    """An abstract interface between the benchmarking framework and storage systems
    (e.g., SQLite or MLflow).
    """

    def __init__(
        self,
        config: Dict[str, Any],
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new storage object.

        Parameters
        ----------
        config : dict
            Free-format key/value pairs of configuration parameters.
        global_config : dict
            Free-format key/value pairs of global configuration parameters.
        service : Service
            An optional service object (e.g., providing methods to load config files, etc.).
        """
        _LOG.debug("Storage config: %s", config)
        self._validate_json_config(config)
        self._service = service
        self._config = config.copy()
        self._global_config = global_config or {}

    def _validate_json_config(self, config: dict) -> None:
        """Reconstruct a basic JSON config that this class might have been instantiated
        from in order to validate configs provided outside the file loading
        mechanism.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config
        ConfigSchema.STORAGE.validate(json_config)
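
    # Illustrative sketch (not part of the original source): a config of the shape
    # below is what `_validate_json_config()` reconstructs and checks against
    # `ConfigSchema.STORAGE`.  The class path and connection settings here are
    # assumptions for illustration only.
    #
    #     {
    #         "class": "mlos_bench.storage.sql.storage.SqlStorage",
    #         "config": {"drivername": "sqlite", "database": "mlos_bench.sqlite"}
    #     }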

    @property
    @abstractmethod
    def experiments(self) -> Dict[str, ExperimentData]:
        """
        Retrieve the experiments' data from the storage.

        Returns
        -------
        experiments : Dict[str, ExperimentData]
            A dictionary of the experiments' data, keyed by experiment id.
        """

    @abstractmethod
    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: Dict[str, Literal["min", "max"]],
    ) -> "Storage.Experiment":
        """
        Create a new experiment in the storage.

        We need the `opt_targets` parameter here to know what metrics to retrieve
        when we load the data from previous trials. Later we will replace it with
        full metadata about the optimization direction, multiple objectives, etc.

        Parameters
        ----------
        experiment_id : str
            Unique identifier of the experiment.
        trial_id : int
            Starting number of the trial.
        root_env_config : str
            A path to the root JSON configuration file of the benchmarking environment.
        description : str
            Human-readable description of the experiment.
        tunables : TunableGroups
            Tunable parameters of the experiment.
        opt_targets : Dict[str, Literal["min", "max"]]
            Names of metrics we're optimizing for and the optimization direction {min, max}.

        Returns
        -------
        experiment : Storage.Experiment
            An object that allows one to update the storage with
            the results of the experiment and related data.
        """

    class Experiment(metaclass=ABCMeta):
        # pylint: disable=too-many-instance-attributes
        """
        Base interface for storing the results of the experiment.

        This class is instantiated in the `Storage.experiment()` method.
        """

        def __init__(  # pylint: disable=too-many-arguments
            self,
            *,
            tunables: TunableGroups,
            experiment_id: str,
            trial_id: int,
            root_env_config: str,
            description: str,
            opt_targets: Dict[str, Literal["min", "max"]],
        ):
            self._tunables = tunables.copy()
            self._trial_id = trial_id
            self._experiment_id = experiment_id
            (self._git_repo, self._git_commit, self._root_env_config) = get_git_info(
                root_env_config
            )
            self._description = description
            self._opt_targets = opt_targets
            self._in_context = False

        def __enter__(self) -> "Storage.Experiment":
            """
            Enter the context of the experiment.

            Override the `_setup` method to add custom context initialization.
            """
            _LOG.debug("Starting experiment: %s", self)
            assert not self._in_context
            self._setup()
            self._in_context = True
            return self

        def __exit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_val: Optional[BaseException],
            exc_tb: Optional[TracebackType],
        ) -> Literal[False]:
            """
            End the context of the experiment.

            Override the `_teardown` method to add custom context teardown logic.
            """
            is_ok = exc_val is None
            if is_ok:
                _LOG.debug("Finishing experiment: %s", self)
            else:
                assert exc_type and exc_val
                _LOG.warning(
                    "Finishing experiment: %s",
                    self,
                    exc_info=(exc_type, exc_val, exc_tb),
                )
            assert self._in_context
            self._teardown(is_ok)
            self._in_context = False
            return False  # Do not suppress exceptions

        def __repr__(self) -> str:
            return self._experiment_id

        def _setup(self) -> None:
            """
            Create a record of the new experiment or find an existing one in the
            storage.

            This method is called by `Storage.Experiment.__enter__()`.
            """

        def _teardown(self, is_ok: bool) -> None:
            """
            Finalize the experiment in the storage.

            This method is called by `Storage.Experiment.__exit__()`.

            Parameters
            ----------
            is_ok : bool
                True if there were no exceptions during the experiment, False otherwise.
            """

        @property
        def experiment_id(self) -> str:
            """Get the Experiment's ID."""
            return self._experiment_id

        @property
        def trial_id(self) -> int:
            """Get the current Trial ID."""
            return self._trial_id

        @property
        def description(self) -> str:
            """Get the Experiment's description."""
            return self._description

        @property
        def tunables(self) -> TunableGroups:
            """Get the Experiment's tunables."""
            return self._tunables

        @property
        def opt_targets(self) -> Dict[str, Literal["min", "max"]]:
            """Get the Experiment's optimization targets and directions."""
            return self._opt_targets

        @abstractmethod
        def merge(self, experiment_ids: List[str]) -> None:
            """
            Merge in the results of other (compatible) experiments' trials. Used to help
            warm up the optimizer for this experiment.

            Parameters
            ----------
            experiment_ids : List[str]
                List of IDs of the experiments to merge in.
            """

        @abstractmethod
        def load_tunable_config(self, config_id: int) -> Dict[str, Any]:
            """Load tunable values for a given config ID."""

        @abstractmethod
        def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
            """
            Retrieve the telemetry data for a given trial.

            Parameters
            ----------
            trial_id : int
                Trial ID.

            Returns
            -------
            metrics : List[Tuple[datetime, str, Any]]
                Telemetry data.
            """

        @abstractmethod
        def load(
            self,
            last_trial_id: int = -1,
        ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
            """
            Load (tunable values, benchmark scores, status) to warm up the optimizer.

            If `last_trial_id` is present, load only the data from the (completed) trials
            that were scheduled *after* the given trial ID. Otherwise, return data from ALL
            merged-in experiments and attempt to impute the missing tunable values.

            Parameters
            ----------
            last_trial_id : int
                (Optional) Trial ID to start from.

            Returns
            -------
            (trial_ids, configs, scores, status) : ([int], [dict], [Optional[dict]], [Status])
                Trial ids, tunable values, benchmark scores, and status of the trials.
            """

        @abstractmethod
        def pending_trials(
            self,
            timestamp: datetime,
            *,
            running: bool,
        ) -> Iterator["Storage.Trial"]:
            """
            Return an iterator over the pending trials that are scheduled to run on or
            before the specified timestamp.

            Parameters
            ----------
            timestamp : datetime
                The time in UTC to check for scheduled trials.
            running : bool
                If True, include the trials that are already running.
                Otherwise, return only the scheduled trials.

            Returns
            -------
            trials : Iterator[Storage.Trial]
                An iterator over the scheduled (and maybe running) trials.
            """

        def new_trial(
            self,
            tunables: TunableGroups,
            ts_start: Optional[datetime] = None,
            config: Optional[Dict[str, Any]] = None,
        ) -> "Storage.Trial":
            """
            Create a new experiment run in the storage.

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters to use for the trial.
            ts_start : Optional[datetime]
                Timestamp of the trial start (can be in the future).
            config : dict
                Key/value pairs of additional non-tunable parameters of the trial.

            Returns
            -------
            trial : Storage.Trial
                An object that allows one to update the storage with
                the results of the experiment trial run.
            """
            # Check that `config` is JSON-serializable (e.g., no callables).
            if config:
                try:
                    # Relies on the fact that DictTemplater only accepts primitive
                    # types in its nested dict structure walk.
                    _config = DictTemplater(config).expand_vars()
                    assert isinstance(_config, dict)
                except ValueError as e:
                    _LOG.error("Non-serializable config: %s", config, exc_info=e)
                    raise e
            return self._new_trial(tunables, ts_start, config)
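
        # Illustrative sketch (not in the original source): a typical trial lifecycle
        # built on top of `new_trial()`.  The extra config key, timestamps, telemetry
        # records, and metric names below are assumptions for illustration only.
        #
        #     trial = exp.new_trial(tunables, config={"extra_param": "value"})
        #     trial.update_telemetry(Status.RUNNING, timestamp, telemetry_records)
        #     trial.update(Status.SUCCEEDED, timestamp, metrics={"score": 0.9})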

        @abstractmethod
        def _new_trial(
            self,
            tunables: TunableGroups,
            ts_start: Optional[datetime] = None,
            config: Optional[Dict[str, Any]] = None,
        ) -> "Storage.Trial":
            """
            Create a new experiment run in the storage.

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters to use for the trial.
            ts_start : Optional[datetime]
                Timestamp of the trial start (can be in the future).
            config : dict
                Key/value pairs of additional non-tunable parameters of the trial.

            Returns
            -------
            trial : Storage.Trial
                An object that allows one to update the storage with
                the results of the experiment trial run.
            """

    class Trial(metaclass=ABCMeta):
        # pylint: disable=too-many-instance-attributes
        """
        Base interface for storing the results of a single run of the experiment.

        This class is instantiated in the `Storage.Experiment.new_trial()` method.
        """

        def __init__(  # pylint: disable=too-many-arguments
            self,
            *,
            tunables: TunableGroups,
            experiment_id: str,
            trial_id: int,
            tunable_config_id: int,
            opt_targets: Dict[str, Literal["min", "max"]],
            config: Optional[Dict[str, Any]] = None,
        ):
            self._tunables = tunables
            self._experiment_id = experiment_id
            self._trial_id = trial_id
            self._tunable_config_id = tunable_config_id
            self._opt_targets = opt_targets
            self._config = config or {}
            self._status = Status.UNKNOWN

        def __repr__(self) -> str:
            return f"{self._experiment_id}:{self._trial_id}:{self._tunable_config_id}"

        @property
        def trial_id(self) -> int:
            """ID of the current trial."""
            return self._trial_id

        @property
        def tunable_config_id(self) -> int:
            """ID of the current trial (tunable) configuration."""
            return self._tunable_config_id

        @property
        def opt_targets(self) -> Dict[str, Literal["min", "max"]]:
            """Get the Trial's optimization targets and directions."""
            return self._opt_targets

        @property
        def tunables(self) -> TunableGroups:
            """
            Tunable parameters of the current trial.

            (e.g., application Environment's "config")
            """
            return self._tunables

        def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
            """
            Produce a copy of the global configuration updated with the parameters of
            the current trial.

            Note: this is not the target Environment's "config" (i.e., tunable
            params), but rather the internal "config" which consists of a
            combination of somewhat more static variables defined in the JSON config
            files.
            """
            config = self._config.copy()
            config.update(global_config or {})
            config["experiment_id"] = self._experiment_id
            config["trial_id"] = self._trial_id
            return config

        @property
        def status(self) -> Status:
            """Get the status of the current trial."""
            return self._status

        @abstractmethod
        def update(
            self,
            status: Status,
            timestamp: datetime,
            metrics: Optional[Dict[str, Any]] = None,
        ) -> Optional[Dict[str, Any]]:
            """
            Update the storage with the results of the experiment.

            Parameters
            ----------
            status : Status
                Status of the experiment run.
            timestamp : datetime
                Timestamp of the status and metrics.
            metrics : Optional[Dict[str, Any]]
                One or several metrics of the experiment run.
                Must contain the (float) optimization target if the status is SUCCEEDED.

            Returns
            -------
            metrics : Optional[Dict[str, Any]]
                Same as `metrics`, but always in the dict format.
            """
            _LOG.info("Store trial: %s :: %s %s", self, status, metrics)
            if status.is_succeeded():
                assert metrics is not None
                opt_targets = set(self._opt_targets.keys())
                if not opt_targets.issubset(metrics.keys()):
                    _LOG.warning(
                        "Trial %s :: opt.targets missing: %s",
                        self,
                        opt_targets.difference(metrics.keys()),
                    )
                    # raise ValueError()
            self._status = status
            return metrics
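
        # Illustrative sketch (not in the original source): concrete subclasses are
        # expected to call this base implementation via `super().update()` to get the
        # status/metrics validation above before persisting the results.  The storage
        # call below is a hypothetical placeholder.
        #
        #     def update(self, status, timestamp, metrics=None):
        #         metrics = super().update(status, timestamp, metrics)
        #         self._save_trial_results(status, timestamp, metrics)  # hypothetical
        #         return metrics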

        @abstractmethod
        def update_telemetry(
            self,
            status: Status,
            timestamp: datetime,
            metrics: List[Tuple[datetime, str, Any]],
        ) -> None:
            """
            Save the experiment's telemetry data and intermediate status.

            Parameters
            ----------
            status : Status
                Current status of the trial.
            timestamp : datetime
                Timestamp of the status (but not the metrics).
            metrics : List[Tuple[datetime, str, Any]]
                Telemetry data.
            """
            _LOG.info("Store telemetry: %s :: %s %d records", self, status, len(metrics))