Coverage for mlos_bench/mlos_bench/storage/base_storage.py: 96% (166 statements)
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Base interface for saving and restoring the benchmark data.

See Also
--------
mlos_bench.storage.base_storage.Storage.experiments :
    Retrieves a dictionary of the Experiments' data.
mlos_bench.storage.base_experiment_data.ExperimentData.results_df :
    Retrieves a pandas DataFrame of the Experiment's trials' results data.
mlos_bench.storage.base_experiment_data.ExperimentData.trials :
    Retrieves a dictionary of the Experiment's trials' data.
mlos_bench.storage.base_experiment_data.ExperimentData.tunable_configs :
    Retrieves a dictionary of the Experiment's sampled configs data.
mlos_bench.storage.base_experiment_data.ExperimentData.tunable_config_trial_groups :
    Retrieves a dictionary of the Experiment's trials' data, grouped by shared
    tunable config.
mlos_bench.storage.base_trial_data.TrialData :
    Base interface for accessing the stored benchmark trial data.
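
Examples
--------
An illustrative sketch of loading a ``Storage`` instance and browsing its data.
The config file path below is hypothetical, and the exact signature of the
``mlos_bench.storage.from_config`` loader may differ in your version::

    from mlos_bench.storage import from_config

    # Load the storage backend described by a JSON config file (hypothetical path).
    storage = from_config(config_file="storage/sqlite.jsonc")
    # Browse the stored experiments, keyed by experiment_id.
    for experiment_id, experiment in storage.experiments.items():
        print(experiment_id, len(experiment.trials))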
"""

import logging
from abc import ABCMeta, abstractmethod
from collections.abc import Iterator, Mapping
from contextlib import AbstractContextManager as ContextManager
from datetime import datetime
from types import TracebackType
from typing import Any, Literal

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.dict_templater import DictTemplater
from mlos_bench.environments.status import Status
from mlos_bench.services.base_service import Service
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import get_git_info

_LOG = logging.getLogger(__name__)


class Storage(metaclass=ABCMeta):
    """An abstract interface between the benchmarking framework and storage systems
    (e.g., SQLite or MLflow).
    """

    def __init__(
        self,
        config: dict[str, Any],
        global_config: dict | None = None,
        service: Service | None = None,
    ):
        """
        Create a new storage object.

        Parameters
        ----------
        config : dict
            Free-format key/value pairs of configuration parameters.
        global_config : dict | None
            Free-format key/value pairs of global configuration parameters (optional).
        service : Service | None
            Optional parent service.
        """
        _LOG.debug("Storage config: %s", config)
        self._validate_json_config(config)
        self._service = service
        self._config = config.copy()
        self._global_config = global_config or {}

    @abstractmethod
    def update_schema(self) -> None:
        """Update the schema of the storage backend if needed."""

    def _validate_json_config(self, config: dict) -> None:
        """Reconstruct a basic JSON config that this class might have been instantiated
        from, in order to validate configs provided outside the file loading
        mechanism.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config
        ConfigSchema.STORAGE.validate(json_config)

    @property
    @abstractmethod
    def experiments(self) -> dict[str, ExperimentData]:
        """
        Retrieve the experiments' data from the storage.

        Returns
        -------
        experiments : dict[str, ExperimentData]
            A dictionary of the experiments' data, keyed by experiment id.
        """

    @abstractmethod
    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: dict[str, Literal["min", "max"]],
    ) -> "Storage.Experiment":
        """
        Create a new experiment in the storage.

        We need the `opt_targets` parameter here to know what metric(s) to retrieve
        when we load the data from previous trials. Later we will replace it with
        full metadata about the optimization direction, multiple objectives, etc.

        Parameters
        ----------
        experiment_id : str
            Unique identifier of the experiment.
        trial_id : int
            Starting number of the trial.
        root_env_config : str
            A path to the root JSON configuration file of the benchmarking environment.
        description : str
            Human-readable description of the experiment.
        tunables : TunableGroups
            Tunable parameters of the benchmarking environment.
        opt_targets : dict[str, Literal["min", "max"]]
            Names of metrics we're optimizing for and the optimization direction {min, max}.

        Returns
        -------
        experiment : Storage.Experiment
            An object that allows updating the storage with
            the results of the experiment and related data.
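
        Examples
        --------
        A hedged sketch of the intended calling pattern; ``storage`` and
        ``tunable_groups`` stand in for objects constructed elsewhere, and the
        config path and IDs are made up for illustration::

            with storage.experiment(
                experiment_id="MyExperiment",
                trial_id=1,
                root_env_config="environments/root.jsonc",
                description="An illustrative experiment.",
                tunables=tunable_groups,
                opt_targets={"score": "min"},
            ) as experiment:
                trial = experiment.new_trial(tunable_groups)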
        """

    class Experiment(ContextManager, metaclass=ABCMeta):
        # pylint: disable=too-many-instance-attributes
        """
        Base interface for storing the results of the experiment.

        This class is instantiated in the `Storage.experiment()` method.
        """

        def __init__(  # pylint: disable=too-many-arguments
            self,
            *,
            tunables: TunableGroups,
            experiment_id: str,
            trial_id: int,
            root_env_config: str,
            description: str,
            opt_targets: dict[str, Literal["min", "max"]],
        ):
            self._tunables = tunables.copy()
            self._trial_id = trial_id
            self._experiment_id = experiment_id
            (self._git_repo, self._git_commit, self._root_env_config) = get_git_info(
                root_env_config
            )
            self._description = description
            self._opt_targets = opt_targets
            self._in_context = False

        def __enter__(self) -> "Storage.Experiment":
            """
            Enter the context of the experiment.

            Override the `_setup` method to add custom context initialization.
            """
            _LOG.debug("Starting experiment: %s", self)
            assert not self._in_context
            self._setup()
            self._in_context = True
            return self

        def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_val: BaseException | None,
            exc_tb: TracebackType | None,
        ) -> Literal[False]:
            """
            End the context of the experiment.

            Override the `_teardown` method to add custom context teardown logic.
            """
            is_ok = exc_val is None
            if is_ok:
                _LOG.debug("Finishing experiment: %s", self)
            else:
                assert exc_type and exc_val
                _LOG.warning(
                    "Finishing experiment: %s",
                    self,
                    exc_info=(exc_type, exc_val, exc_tb),
                )
            assert self._in_context
            self._teardown(is_ok)
            self._in_context = False
            return False  # Do not suppress exceptions

        def __repr__(self) -> str:
            return self._experiment_id

        def _setup(self) -> None:
            """
            Create a record of the new experiment or find an existing one in the
            storage.

            This method is called by `Storage.Experiment.__enter__()`.
            """

        def _teardown(self, is_ok: bool) -> None:
            """
            Finalize the experiment in the storage.

            This method is called by `Storage.Experiment.__exit__()`.

            Parameters
            ----------
            is_ok : bool
                True if there were no exceptions during the experiment, False otherwise.
            """

        @property
        def experiment_id(self) -> str:
            """Get the Experiment's ID."""
            return self._experiment_id

        @property
        def trial_id(self) -> int:
            """Get the current Trial ID."""
            return self._trial_id

        @property
        def description(self) -> str:
            """Get the Experiment's description."""
            return self._description

        @property
        def root_env_config(self) -> str:
            """Get the Experiment's root Environment config file path."""
            return self._root_env_config

        @property
        def tunables(self) -> TunableGroups:
            """Get the Experiment's tunables."""
            return self._tunables

        @property
        def opt_targets(self) -> dict[str, Literal["min", "max"]]:
            """Get the Experiment's optimization targets and directions."""
            return self._opt_targets

        @abstractmethod
        def merge(self, experiment_ids: list[str]) -> None:
            """
            Merge in the results of other (compatible) experiments' trials. Used to
            help warm up the optimizer for this experiment.

            Parameters
            ----------
            experiment_ids : list[str]
                List of IDs of the experiments to merge in.
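
            Examples
            --------
            Illustrative only; the IDs below stand in for previously stored,
            compatible experiments::

                experiment.merge(["MyExperiment_v1", "MyExperiment_v2"])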
            """

        @abstractmethod
        def load_tunable_config(self, config_id: int) -> dict[str, Any]:
            """Load tunable values for a given config ID."""

        @abstractmethod
        def load_telemetry(self, trial_id: int) -> list[tuple[datetime, str, Any]]:
            """
            Retrieve the telemetry data for a given trial.

            Parameters
            ----------
            trial_id : int
                Trial ID.

            Returns
            -------
            metrics : list[tuple[datetime.datetime, str, Any]]
                Telemetry data.
            """

        @abstractmethod
        def load(
            self,
            last_trial_id: int = -1,
        ) -> tuple[list[int], list[dict], list[dict[str, Any] | None], list[Status]]:
            """
            Load (tunable values, benchmark scores, status) to warm-up the optimizer.

            If `last_trial_id` is present, load only the data from the (completed) trials
            that were scheduled *after* the given trial ID. Otherwise, return data from ALL
            merged-in experiments and attempt to impute the missing tunable values.

            Parameters
            ----------
            last_trial_id : int
                (Optional) Trial ID to start from.

            Returns
            -------
            (trial_ids, configs, scores, status) : ([int], [dict], [dict | None], [Status])
                Trial IDs, tunable values, benchmark scores, and status of the trials.
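
            Examples
            --------
            A hedged sketch of consuming the parallel lists for optimizer warm-up;
            ``optimizer`` and its ``register`` method are placeholders, and each
            ``score`` entry is a dict of metric values (or None)::

                (trial_ids, configs, scores, status) = experiment.load()
                for config, score, trial_status in zip(configs, scores, status):
                    if trial_status.is_succeeded() and score is not None:
                        optimizer.register(config, score)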
            """

        @abstractmethod
        def pending_trials(
            self,
            timestamp: datetime,
            *,
            running: bool,
        ) -> Iterator["Storage.Trial"]:
            """
            Return an iterator over the pending trials that are scheduled to run on or
            before the specified timestamp.

            Parameters
            ----------
            timestamp : datetime.datetime
                The time in UTC to check for scheduled trials.
            running : bool
                If True, include the trials that are already running.
                Otherwise, return only the scheduled trials.

            Returns
            -------
            trials : Iterator[Storage.Trial]
                An iterator over the scheduled (and maybe running) trials.
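
            Examples
            --------
            Illustrative only; a scheduler-style poll for trials that are due to
            run now but not yet started::

                from datetime import datetime, timezone

                for trial in experiment.pending_trials(
                    datetime.now(timezone.utc),
                    running=False,
                ):
                    print(trial.trial_id, trial.tunable_config_id)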
            """

        def new_trial(
            self,
            tunables: TunableGroups,
            ts_start: datetime | None = None,
            config: dict[str, Any] | None = None,
        ) -> "Storage.Trial":
            """
            Create a new experiment run in the storage.

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters to use for the trial.
            ts_start : datetime.datetime | None
                Timestamp of the trial start (can be in the future).
            config : dict
                Key/value pairs of additional non-tunable parameters of the trial.

            Returns
            -------
            trial : Storage.Trial
                An object that allows updating the storage with
                the results of the experiment trial run.
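
            Examples
            --------
            A hedged sketch; ``tunable_groups`` stands in for a ``TunableGroups``
            instance built elsewhere, and the extra config key is made up::

                trial = experiment.new_trial(
                    tunable_groups,
                    config={"vm_size": "Standard_B2s"},
                )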
            """
            # Check that `config` is json serializable (e.g., no callables)
            if config:
                try:
                    # Relies on the fact that DictTemplater only accepts primitive
                    # types in its nested dict structure walk.
                    _config = DictTemplater(config).expand_vars()
                    assert isinstance(_config, dict)
                except ValueError as e:
                    _LOG.error("Non-serializable config: %s", config, exc_info=e)
                    raise e
            return self._new_trial(tunables, ts_start, config)

        @abstractmethod
        def _new_trial(
            self,
            tunables: TunableGroups,
            ts_start: datetime | None = None,
            config: dict[str, Any] | None = None,
        ) -> "Storage.Trial":
            """
            Create a new experiment run in the storage.

            Parameters
            ----------
            tunables : TunableGroups
                Tunable parameters to use for the trial.
            ts_start : datetime.datetime | None
                Timestamp of the trial start (can be in the future).
            config : dict
                Key/value pairs of additional non-tunable parameters of the trial.

            Returns
            -------
            trial : Storage.Trial
                An object that allows updating the storage with
                the results of the experiment trial run.
            """

    class Trial(metaclass=ABCMeta):
        # pylint: disable=too-many-instance-attributes
        """
        Base interface for storing the results of a single run of the experiment.

        This class is instantiated in the `Storage.Experiment.new_trial()` method.
        """

        def __init__(  # pylint: disable=too-many-arguments
            self,
            *,
            tunables: TunableGroups,
            experiment_id: str,
            trial_id: int,
            tunable_config_id: int,
            trial_runner_id: int | None = None,
            opt_targets: dict[str, Literal["min", "max"]],
            config: dict[str, Any] | None = None,
            status: Status = Status.UNKNOWN,
        ):
            if status not in (Status.UNKNOWN, Status.PENDING):
                raise ValueError(f"Invalid status for a new trial: {status}")
            self._tunables = tunables
            self._experiment_id = experiment_id
            self._trial_id = trial_id
            self._tunable_config_id = tunable_config_id
            self._trial_runner_id = trial_runner_id
            self._opt_targets = opt_targets
            self._config = config or {}
            self._status = status

        def __repr__(self) -> str:
            return (
                f"{self._experiment_id}:{self._trial_id}:"
                f"{self._tunable_config_id}:{self.trial_runner_id}"
            )

        @property
        def trial_id(self) -> int:
            """ID of the current trial."""
            return self._trial_id

        @property
        def tunable_config_id(self) -> int:
            """ID of the current trial (tunable) configuration."""
            return self._tunable_config_id

        @property
        def trial_runner_id(self) -> int | None:
            """ID of the TrialRunner this trial is assigned to."""
            return self._trial_runner_id

        @property
        def opt_targets(self) -> dict[str, Literal["min", "max"]]:
            """Get the Trial's optimization targets and directions."""
            return self._opt_targets

        @property
        def tunables(self) -> TunableGroups:
            """
            Tunable parameters of the current trial.

            (e.g., application Environment's "config")
            """
            return self._tunables

        @abstractmethod
        def set_trial_runner(self, trial_runner_id: int) -> int:
            """Assign the trial to a specific TrialRunner."""
            if self._trial_runner_id is None or self._status.is_pending():
                _LOG.debug(
                    "%sSetting Trial %s to TrialRunner %d",
                    "Re-" if self._trial_runner_id else "",
                    self,
                    trial_runner_id,
                )
                self._trial_runner_id = trial_runner_id
            else:
                _LOG.warning(
                    "Trial %s already assigned to TrialRunner %d, cannot switch to %d",
                    self,
                    self._trial_runner_id,
                    trial_runner_id,
                )
            return self._trial_runner_id

        def config(self, global_config: dict[str, Any] | None = None) -> dict[str, Any]:
            """
            Produce a copy of the global configuration updated with the parameters of
            the current trial.

            Note: this is not the target Environment's "config" (i.e., tunable
            params), but rather the internal "config" which consists of a
            combination of somewhat more static variables defined in the json config
            files.
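
            Examples
            --------
            Illustrative only; the extra key below is hypothetical, but the built-in
            ``experiment_id``, ``trial_id``, and (when assigned) ``trial_runner_id``
            entries are always added::

                run_config = trial.config({"deployment": "staging"})
                # e.g., {"deployment": "staging", "experiment_id": "MyExperiment",
                #        "trial_id": 42, "trial_runner_id": 1}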
            """
            config = self._config.copy()
            config.update(global_config or {})
            # Here we add some built-in variables for the trial to use while it's running.
            config["experiment_id"] = self._experiment_id
            config["trial_id"] = self._trial_id
            trial_runner_id = self.trial_runner_id
            if trial_runner_id is not None:
                config["trial_runner_id"] = trial_runner_id
            return config

        def add_new_config_data(
            self,
            new_config_data: Mapping[str, int | float | str],
        ) -> None:
            """
            Add new config data to the trial.

            Parameters
            ----------
            new_config_data : dict[str, int | float | str]
                New data to add (must not already exist for the trial).

            Raises
            ------
            ValueError
                If any of the data already exists.
            """
            for key, value in new_config_data.items():
                if key in self._config:
                    raise ValueError(
                        f"New config data {key}={value} already exists for trial {self}: "
                        f"{self._config[key]}"
                    )
                self._config[key] = value
            self._save_new_config_data(new_config_data)

        @abstractmethod
        def _save_new_config_data(
            self,
            new_config_data: Mapping[str, int | float | str],
        ) -> None:
            """
            Save the new config data to the storage.

            Parameters
            ----------
            new_config_data : dict[str, int | float | str]
                New data to add.
            """

        @property
        def status(self) -> Status:
            """Get the status of the current trial."""
            return self._status

        @abstractmethod
        def update(
            self,
            status: Status,
            timestamp: datetime,
            metrics: dict[str, Any] | None = None,
        ) -> dict[str, Any] | None:
            """
            Update the storage with the results of the experiment.

            Parameters
            ----------
            status : Status
                Status of the experiment run.
            timestamp : datetime.datetime
                Timestamp of the status and metrics.
            metrics : dict[str, Any] | None
                One or several metrics of the experiment run.
                Must contain the (float) optimization target if the status is SUCCEEDED.

            Returns
            -------
            metrics : dict[str, Any] | None
                Same as `metrics`, but always in the dict format.
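
            Examples
            --------
            A hedged sketch of reporting a successful trial; the metric name is
            assumed to match one of the Trial's ``opt_targets``::

                from datetime import datetime, timezone

                from mlos_bench.environments.status import Status

                trial.update(
                    Status.SUCCEEDED,
                    datetime.now(timezone.utc),
                    metrics={"score": 0.87},
                )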
            """
            _LOG.info("Store trial: %s :: %s %s", self, status, metrics)
            if status.is_succeeded():
                assert metrics is not None
                opt_targets = set(self._opt_targets.keys())
                if not opt_targets.issubset(metrics.keys()):
                    _LOG.warning(
                        "Trial %s :: opt.targets missing: %s",
                        self,
                        opt_targets.difference(metrics.keys()),
                    )
                    # raise ValueError()
            self._status = status
            return metrics

        @abstractmethod
        def update_telemetry(
            self,
            status: Status,
            timestamp: datetime,
            metrics: list[tuple[datetime, str, Any]],
        ) -> None:
            """
            Save the experiment's telemetry data and intermediate status.

            Parameters
            ----------
            status : Status
                Current status of the trial.
            timestamp : datetime.datetime
                Timestamp of the status (but not the metrics).
            metrics : list[tuple[datetime.datetime, str, Any]]
                Telemetry data.
            """
            _LOG.info("Store telemetry: %s :: %s %d records", self, status, len(metrics))