Coverage for mlos_bench/mlos_bench/storage/sql/storage.py: 100%
38 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Saving and restoring the benchmark data in SQL database."""
7import logging
8from typing import Dict, Literal, Optional
10from sqlalchemy import URL, create_engine
12from mlos_bench.services.base_service import Service
13from mlos_bench.storage.base_experiment_data import ExperimentData
14from mlos_bench.storage.base_storage import Storage
15from mlos_bench.storage.sql.experiment import Experiment
16from mlos_bench.storage.sql.experiment_data import ExperimentSqlData
17from mlos_bench.storage.sql.schema import DbSchema
18from mlos_bench.tunables.tunable_groups import TunableGroups
20_LOG = logging.getLogger(__name__)
class SqlStorage(Storage):
    """
    An implementation of the Storage interface using SQLAlchemy backend.

    The storage owns a single SQLAlchemy engine and a `DbSchema` object that is
    created either eagerly in the constructor or lazily on first access of the
    `_schema` property.
    """

    def __init__(
        self,
        config: dict,
        global_config: Optional[dict] = None,
        service: Optional[Service] = None,
    ):
        """
        Set up the database engine and, unless disabled, create the schema.

        Parameters
        ----------
        config : dict
            Storage configuration. The optional keys ``lazy_schema_create`` and
            ``log_sql`` (both default to False) are removed from the stored
            config; all remaining keys are passed as keyword arguments to
            `sqlalchemy.URL.create`.
        global_config : Optional[dict]
            Forwarded unchanged to the base `Storage` constructor.
        service : Optional[Service]
            Forwarded unchanged to the base `Storage` constructor.
        """
        super().__init__(config, global_config, service)
        # These two keys are consumed here; they must be removed before the
        # rest of the config is handed to `URL.create`.
        lazy_schema_create = self._config.pop("lazy_schema_create", False)
        self._log_sql = self._config.pop("log_sql", False)
        self._url = URL.create(**self._config)
        self._repr = f"{self._url.get_backend_name()}:{self._url.database}"
        _LOG.info("Connect to the database: %s", self)
        self._engine = create_engine(self._url, echo=self._log_sql)
        # Declared but deliberately left unassigned: the `_schema` property
        # uses `hasattr` to detect whether the schema has been created yet.
        self._db_schema: DbSchema
        if lazy_schema_create:
            _LOG.info("Using lazy schema create for database: %s", self)
        else:
            # Access the property for its side effect of creating the schema.
            # This must not be an `assert`: assertions are stripped when Python
            # runs with -O, which would silently skip the eager creation.
            _ = self._schema

    @property
    def _schema(self) -> DbSchema:
        """Lazily create schema upon first access."""
        if not hasattr(self, "_db_schema"):
            self._db_schema = DbSchema(self._engine).create()
            if _LOG.isEnabledFor(logging.DEBUG):
                # Presumably the schema's string form lists its DDL statements.
                _LOG.debug("DDL statements:\n%s", self._db_schema)
        return self._db_schema

    def __repr__(self) -> str:
        """Return ``<backend name>:<database>`` as computed in the constructor."""
        return self._repr

    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: Dict[str, Literal["min", "max"]],
    ) -> Storage.Experiment:
        """
        Build an `Experiment` bound to this storage's engine and schema.

        All keyword arguments are forwarded unchanged to the `Experiment`
        constructor. Accessing `_schema` here creates the database schema if it
        has not been created yet.
        """
        return Experiment(
            engine=self._engine,
            schema=self._schema,
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )

    @property
    def experiments(self) -> Dict[str, ExperimentData]:
        """
        Fetch all rows of the experiment table, ordered by ascending ``exp_id``.

        Returns
        -------
        Dict[str, ExperimentData]
            A dictionary mapping each ``exp_id`` to an `ExperimentSqlData`
            object built from that row.
        """
        # FIXME: this is somewhat expensive if only fetching a single Experiment.
        # May need to expand the API or data structures to lazily fetch data and/or cache it.
        with self._engine.connect() as conn:
            cur_exp = conn.execute(
                self._schema.experiment.select().order_by(
                    self._schema.experiment.c.exp_id.asc(),
                )
            )
            # Rows are materialized while the connection is still open.
            return {
                exp.exp_id: ExperimentSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    experiment_id=exp.exp_id,
                    description=exp.description,
                    root_env_config=exp.root_env_config,
                    git_repo=exp.git_repo,
                    git_commit=exp.git_commit,
                )
                for exp in cur_exp.fetchall()
            }