Coverage for mlos_bench/mlos_bench/storage/base_storage.py: 96%
137 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Base interface for saving and restoring the benchmark data."""
7import logging
8from abc import ABCMeta, abstractmethod
9from datetime import datetime
10from types import TracebackType
11from typing import Any, Dict, Iterator, List, Optional, Tuple, Type
13from typing_extensions import Literal
15from mlos_bench.config.schemas import ConfigSchema
16from mlos_bench.dict_templater import DictTemplater
17from mlos_bench.environments.status import Status
18from mlos_bench.services.base_service import Service
19from mlos_bench.storage.base_experiment_data import ExperimentData
20from mlos_bench.tunables.tunable_groups import TunableGroups
21from mlos_bench.util import get_git_info
# Module-level logger, named after this module; used by the Storage classes below.
_LOG = logging.getLogger(__name__)
class Storage(metaclass=ABCMeta):
    """An abstract interface between the benchmarking framework and storage systems
    (e.g., SQLite or MLflow).
    """

    def __init__(
        self,
        config: Dict[str, Any],
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        """
        Create a new storage object.

        Parameters
        ----------
        config : dict
            Free-format key/value pairs of configuration parameters.
            Validated against the storage config schema; a shallow copy is kept.
        global_config : Optional[dict]
            Free-format key/value pairs of global parameters.
            An empty dict is used when omitted (or empty).
        service : Optional[Service]
            Optional Service instance; stored on the object as-is.
        """
        _LOG.debug("Storage config: %s", config)
        self._validate_json_config(config)
        self._service = service
        self._config = config.copy()
        self._global_config = global_config or {}

    def _validate_json_config(self, config: dict) -> None:
        """Reconstructs a basic json config that this class might have been instantiated
        from in order to validate configs provided outside the file loading
        mechanism.

        Parameters
        ----------
        config : dict
            The storage configuration to validate. It is added to the
            reconstructed json under the "config" key only if it is non-empty.
        """
        json_config: dict = {
            # Fully qualified name of the concrete storage class.
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config
        ConfigSchema.STORAGE.validate(json_config)

    @property
    @abstractmethod
    def experiments(self) -> Dict[str, ExperimentData]:
        """
        Retrieve the experiments' data from the storage.

        Returns
        -------
        experiments : Dict[str, ExperimentData]
            A dictionary of the experiments' data, keyed by experiment id.
        """

    @abstractmethod
    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: Dict[str, Literal["min", "max"]],
    ) -> "Storage.Experiment":
        """
        Create a new experiment in the storage.

        We need the `opt_targets` parameter here to know what metric to retrieve
        when we load the data from previous trials. Later we will replace it with
        full metadata about the optimization direction, multiple objectives, etc.

        Parameters
        ----------
        experiment_id : str
            Unique identifier of the experiment.
        trial_id : int
            Starting number of the trial.
        root_env_config : str
            A path to the root JSON configuration file of the benchmarking environment.
        description : str
            Human-readable description of the experiment.
        tunables : TunableGroups
            Tunable parameters of the experiment.
        opt_targets : Dict[str, Literal["min", "max"]]
            Names of metrics we're optimizing for and the optimization direction {min, max}.

        Returns
        -------
        experiment : Storage.Experiment
            An object that allows to update the storage with
            the results of the experiment and related data.
        """

    class Experiment(metaclass=ABCMeta):
        # pylint: disable=too-many-instance-attributes
        """
        Base interface for storing the results of the experiment.

        This class is instantiated in the `Storage.experiment()` method.
        """

        def __init__(  # pylint: disable=too-many-arguments
            self,
            *,
            tunables: TunableGroups,
            experiment_id: str,
            trial_id: int,
            root_env_config: str,
            description: str,
            opt_targets: Dict[str, Literal["min", "max"]],
        ):
            """
            Initialize the in-memory state of the experiment.

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters of the experiment; a copy is stored.
            experiment_id : str
                Unique identifier of the experiment.
            trial_id : int
                Starting number of the trial.
            root_env_config : str
                Root configuration of the environment; passed to `get_git_info()`
                to obtain the git repo, the git commit, and the stored config path.
            description : str
                Human-readable description of the experiment.
            opt_targets : Dict[str, Literal["min", "max"]]
                Names of metrics we're optimizing for and the optimization direction.
            """
            self._tunables = tunables.copy()
            self._trial_id = trial_id
            self._experiment_id = experiment_id
            (self._git_repo, self._git_commit, self._root_env_config) = get_git_info(
                root_env_config
            )
            self._description = description
            self._opt_targets = opt_targets
            # Guards against nested or unbalanced use of the context manager.
            self._in_context = False

        def __enter__(self) -> "Storage.Experiment":
            """
            Enter the context of the experiment.

            Override the `_setup` method to add custom context initialization.
            """
            _LOG.debug("Starting experiment: %s", self)
            assert not self._in_context
            self._setup()
            # Only mark the context as entered once the setup has succeeded.
            self._in_context = True
            return self

        def __exit__(
            self,
            exc_type: Optional[Type[BaseException]],
            exc_val: Optional[BaseException],
            exc_tb: Optional[TracebackType],
        ) -> Literal[False]:
            """
            End the context of the experiment.

            Override the `_teardown` method to add custom context teardown logic.

            Returns
            -------
            suppress : Literal[False]
                Always False, i.e., exceptions raised in the context are propagated.
            """
            is_ok = exc_val is None
            if is_ok:
                _LOG.debug("Finishing experiment: %s", self)
            else:
                assert exc_type and exc_val
                _LOG.warning(
                    "Finishing experiment: %s",
                    self,
                    exc_info=(exc_type, exc_val, exc_tb),
                )
            assert self._in_context
            try:
                self._teardown(is_ok)
            finally:
                # The `with` block is over even if the teardown itself fails,
                # so never leave the object marked as being in the context.
                self._in_context = False
            return False  # Do not suppress exceptions

        def __repr__(self) -> str:
            """Represent the experiment by its ID."""
            return self._experiment_id

        def _setup(self) -> None:
            """
            Create a record of the new experiment or find an existing one in the
            storage.

            This method is called by `Storage.Experiment.__enter__()`.
            """

        def _teardown(self, is_ok: bool) -> None:
            """
            Finalize the experiment in the storage.

            This method is called by `Storage.Experiment.__exit__()`.

            Parameters
            ----------
            is_ok : bool
                True if there were no exceptions during the experiment, False otherwise.
            """

        @property
        def experiment_id(self) -> str:
            """Get the Experiment's ID."""
            return self._experiment_id

        @property
        def trial_id(self) -> int:
            """Get the current Trial ID."""
            return self._trial_id

        @property
        def description(self) -> str:
            """Get the Experiment's description."""
            return self._description

        @property
        def tunables(self) -> TunableGroups:
            """Get the Experiment's tunables."""
            return self._tunables

        @property
        def opt_targets(self) -> Dict[str, Literal["min", "max"]]:
            """Get the Experiment's optimization targets and directions."""
            return self._opt_targets

        @abstractmethod
        def merge(self, experiment_ids: List[str]) -> None:
            """
            Merge in the results of other (compatible) experiments trials. Used to help
            warm up the optimizer for this experiment.

            Parameters
            ----------
            experiment_ids : List[str]
                List of IDs of the experiments to merge in.
            """

        @abstractmethod
        def load_tunable_config(self, config_id: int) -> Dict[str, Any]:
            """Load tunable values for a given config ID."""

        @abstractmethod
        def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
            """
            Retrieve the telemetry data for a given trial.

            Parameters
            ----------
            trial_id : int
                Trial ID.

            Returns
            -------
            metrics : List[Tuple[datetime, str, Any]]
                Telemetry data.
            """

        @abstractmethod
        def load(
            self,
            last_trial_id: int = -1,
        ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
            """
            Load (tunable values, benchmark scores, status) to warm-up the optimizer.

            If `last_trial_id` is present, load only the data from the (completed) trials
            that were scheduled *after* the given trial ID. Otherwise, return data from ALL
            merged-in experiments and attempt to impute the missing tunable values.

            Parameters
            ----------
            last_trial_id : int
                (Optional) Trial ID to start from.

            Returns
            -------
            (trial_ids, configs, scores, status) : ([int], [dict], [Optional[dict]], [Status])
                Trial ids, Tunable values, benchmark scores, and status of the trials.
            """

        @abstractmethod
        def pending_trials(
            self,
            timestamp: datetime,
            *,
            running: bool,
        ) -> Iterator["Storage.Trial"]:
            """
            Return an iterator over the pending trials that are scheduled to run on or
            before the specified timestamp.

            Parameters
            ----------
            timestamp : datetime
                The time in UTC to check for scheduled trials.
            running : bool
                If True, include the trials that are already running.
                Otherwise, return only the scheduled trials.

            Returns
            -------
            trials : Iterator[Storage.Trial]
                An iterator over the scheduled (and maybe running) trials.
            """

        def new_trial(
            self,
            tunables: TunableGroups,
            ts_start: Optional[datetime] = None,
            config: Optional[Dict[str, Any]] = None,
        ) -> "Storage.Trial":
            """
            Create a new experiment run in the storage.

            Validates `config` (if any) and delegates the actual creation of the
            trial to the `_new_trial()` method.

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters to use for the trial.
            ts_start : Optional[datetime]
                Timestamp of the trial start (can be in the future).
            config : dict
                Key/value pairs of additional non-tunable parameters of the trial.

            Returns
            -------
            trial : Storage.Trial
                An object that allows to update the storage with
                the results of the experiment trial run.

            Raises
            ------
            ValueError
                If the `config` expansion fails, i.e., `config` is not serializable.
            """
            # Check that `config` is json serializable (e.g., no callables)
            if config:
                try:
                    # Relies on the fact that DictTemplater only accepts primitive
                    # types in its nested dict structure walk.
                    _config = DictTemplater(config).expand_vars()
                    assert isinstance(_config, dict)
                except ValueError as e:
                    _LOG.error("Non-serializable config: %s", config, exc_info=e)
                    raise
            # Note: the original (unexpanded) config is passed on to the storage.
            return self._new_trial(tunables, ts_start, config)

        @abstractmethod
        def _new_trial(
            self,
            tunables: TunableGroups,
            ts_start: Optional[datetime] = None,
            config: Optional[Dict[str, Any]] = None,
        ) -> "Storage.Trial":
            """
            Create a new experiment run in the storage.

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters to use for the trial.
            ts_start : Optional[datetime]
                Timestamp of the trial start (can be in the future).
            config : dict
                Key/value pairs of additional non-tunable parameters of the trial.

            Returns
            -------
            trial : Storage.Trial
                An object that allows to update the storage with
                the results of the experiment trial run.
            """

    class Trial(metaclass=ABCMeta):
        # pylint: disable=too-many-instance-attributes
        """
        Base interface for storing the results of a single run of the experiment.

        This class is instantiated in the `Storage.Experiment.new_trial()` method.
        """

        def __init__(  # pylint: disable=too-many-arguments
            self,
            *,
            tunables: TunableGroups,
            experiment_id: str,
            trial_id: int,
            tunable_config_id: int,
            opt_targets: Dict[str, Literal["min", "max"]],
            config: Optional[Dict[str, Any]] = None,
        ):
            """
            Initialize the in-memory state of the trial.

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters of the trial (stored as-is, not copied).
            experiment_id : str
                ID of the experiment the trial belongs to.
            trial_id : int
                ID of the trial.
            tunable_config_id : int
                ID of the trial's (tunable) configuration.
            opt_targets : Dict[str, Literal["min", "max"]]
                Names of metrics we're optimizing for and the optimization direction.
            config : Optional[Dict[str, Any]]
                Additional non-tunable parameters of the trial; defaults to an empty dict.
            """
            self._tunables = tunables
            self._experiment_id = experiment_id
            self._trial_id = trial_id
            self._tunable_config_id = tunable_config_id
            self._opt_targets = opt_targets
            self._config = config or {}
            # The status is unknown until the first call to `update()`.
            self._status = Status.UNKNOWN

        def __repr__(self) -> str:
            """Represent the trial as `experiment_id:trial_id:tunable_config_id`."""
            return f"{self._experiment_id}:{self._trial_id}:{self._tunable_config_id}"

        @property
        def trial_id(self) -> int:
            """ID of the current trial."""
            return self._trial_id

        @property
        def tunable_config_id(self) -> int:
            """ID of the current trial (tunable) configuration."""
            return self._tunable_config_id

        @property
        def opt_targets(self) -> Dict[str, Literal["min", "max"]]:
            """Get the Trial's optimization targets and directions."""
            return self._opt_targets

        @property
        def tunables(self) -> TunableGroups:
            """
            Tunable parameters of the current trial.

            (e.g., application Environment's "config")
            """
            return self._tunables

        def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
            """
            Produce a copy of the current trial's config updated with the global
            configuration, plus the "experiment_id" and "trial_id" of the trial.

            On key conflicts the values from `global_config` take precedence over
            the trial's own config; "experiment_id" and "trial_id" are always set
            from the trial itself.

            Note: this is not the target Environment's "config" (i.e., tunable
            params), but rather the internal "config" which consists of a
            combination of somewhat more static variables defined in the json config
            files.

            NOTE(review): an earlier revision of this docstring described the
            opposite precedence (global config updated with the trial's params);
            the code was left as is - confirm which one is intended.

            Parameters
            ----------
            global_config : Optional[Dict[str, Any]]
                Global parameters to merge in; not modified by this method.

            Returns
            -------
            config : Dict[str, Any]
                A new dictionary with the merged parameters.
            """
            config = self._config.copy()
            config.update(global_config or {})
            config["experiment_id"] = self._experiment_id
            config["trial_id"] = self._trial_id
            return config

        @property
        def status(self) -> Status:
            """Get the status of the current trial."""
            return self._status

        @abstractmethod
        def update(
            self,
            status: Status,
            timestamp: datetime,
            metrics: Optional[Dict[str, Any]] = None,
        ) -> Optional[Dict[str, Any]]:
            """
            Update the storage with the results of the experiment.

            This base implementation only logs the data, checks the metrics, and
            records the new status in memory; presumably the subclasses call it
            before writing to the actual storage.

            Parameters
            ----------
            status : Status
                Status of the experiment run.
            timestamp: datetime
                Timestamp of the status and metrics.
            metrics : Optional[Dict[str, Any]]
                One or several metrics of the experiment run.
                Must contain the (float) optimization target if the status is SUCCEEDED.

            Returns
            -------
            metrics : Optional[Dict[str, Any]]
                Same as `metrics`, but always in the dict format.
            """
            _LOG.info("Store trial: %s :: %s %s", self, status, metrics)
            if status.is_succeeded():
                assert metrics is not None
                opt_targets = set(self._opt_targets.keys())
                if not opt_targets.issubset(metrics.keys()):
                    # NOTE: missing optimization targets are only reported here;
                    # no exception is raised.
                    _LOG.warning(
                        "Trial %s :: opt.targets missing: %s",
                        self,
                        opt_targets.difference(metrics.keys()),
                    )
            self._status = status
            return metrics

        @abstractmethod
        def update_telemetry(
            self,
            status: Status,
            timestamp: datetime,
            metrics: List[Tuple[datetime, str, Any]],
        ) -> None:
            """
            Save the experiment's telemetry data and intermediate status.

            This base implementation only logs the number of the telemetry records.

            Parameters
            ----------
            status : Status
                Current status of the trial.
            timestamp: datetime
                Timestamp of the status (but not the metrics).
            metrics : List[Tuple[datetime, str, Any]]
                Telemetry data.
            """
            _LOG.info("Store telemetry: %s :: %s %d records", self, status, len(metrics))