Coverage for mlos_bench/mlos_bench/storage/sql/storage.py: 100%
46 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Saving and restoring the benchmark data in SQL database."""
7import logging
8from typing import Literal
10from sqlalchemy import URL, create_engine
12from mlos_bench.services.base_service import Service
13from mlos_bench.storage.base_experiment_data import ExperimentData
14from mlos_bench.storage.base_storage import Storage
15from mlos_bench.storage.sql.experiment import Experiment
16from mlos_bench.storage.sql.experiment_data import ExperimentSqlData
17from mlos_bench.storage.sql.schema import DbSchema
18from mlos_bench.tunables.tunable_groups import TunableGroups
# Module-level logger, named after this module.
_LOG: logging.Logger = logging.getLogger(name=__name__)
class SqlStorage(Storage):
    """An implementation of the :py:class:`~.Storage` interface using SQLAlchemy
    backend.
    """

    def __init__(
        self,
        config: dict,
        global_config: dict | None = None,
        service: Service | None = None,
    ):
        """
        Create the SQLAlchemy engine and, unless deferred, the database schema.

        Parameters
        ----------
        config : dict
            Storage configuration. After the base class initializer runs, the
            optional keys ``lazy_schema_create`` (default ``False``) and
            ``log_sql`` (default ``False``) are popped from ``self._config``;
            every remaining key is passed as a keyword argument to
            ``sqlalchemy.URL.create()``.
        global_config : dict | None
            Global configuration, forwarded to the base class.
        service : Service | None
            Optional parent service, forwarded to the base class.
        """
        super().__init__(config, global_config, service)
        # Remove the non-URL options first so that whatever is left in the
        # config maps one-to-one onto URL.create() keyword arguments.
        lazy_schema_create = self._config.pop("lazy_schema_create", False)
        self._log_sql = self._config.pop("log_sql", False)
        self._url = URL.create(**self._config)
        # Short "<backend>:<database>" label returned by __repr__.
        self._repr = f"{self._url.get_backend_name()}:{self._url.database}"
        _LOG.info("Connect to the database: %s", self)
        self._engine = create_engine(self._url, echo=self._log_sql)
        self._db_schema = DbSchema(self._engine)
        # Flags that keep schema creation / update from re-running once done.
        self._schema_created = False
        self._schema_updated = False
        if lazy_schema_create:
            _LOG.info("Using lazy schema create for database: %s", self)
        else:
            # update_schema() reads the ``_schema`` property, which creates the
            # schema before updating it. Creation is deliberately not triggered
            # via an ``assert``: asserts are stripped under ``python -O``.
            self.update_schema()

    @property
    def _schema(self) -> DbSchema:
        """
        Return the schema object, creating the database schema on first access.

        ``DbSchema.create()`` is invoked until it has completed once; after
        that the ``_schema_created`` flag short-circuits it.
        """
        if not self._schema_created:
            self._db_schema.create()
            self._schema_created = True
            if _LOG.isEnabledFor(logging.DEBUG):
                # Presumably str(DbSchema) renders the DDL -- confirm in schema.py.
                _LOG.debug("DDL statements:\n%s", self._db_schema)
        return self._db_schema

    def update_schema(self) -> None:
        """
        Update the database schema.

        Goes through the ``_schema`` property, so the schema is created first
        if that has not happened yet. ``DbSchema.update()`` is invoked until it
        has completed once; later calls are no-ops.
        """
        if not self._schema_updated:
            self._schema.update()
            self._schema_updated = True

    def __repr__(self) -> str:
        """Return the ``"<backend>:<database>"`` label computed at init."""
        return self._repr

    def experiment(  # pylint: disable=too-many-arguments
        self,
        *,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        tunables: TunableGroups,
        opt_targets: dict[str, Literal["min", "max"]],
    ) -> Storage.Experiment:
        """
        Create an experiment object backed by this database.

        Reading ``self._schema`` here creates the schema if its creation was
        deferred via ``lazy_schema_create`` (``update_schema()`` is not called).
        All keyword arguments are passed through unchanged to
        :py:class:`~mlos_bench.storage.sql.experiment.Experiment` together with
        this storage's engine and schema.

        Returns
        -------
        Storage.Experiment
            The SQL-backed experiment instance.
        """
        return Experiment(
            engine=self._engine,
            schema=self._schema,
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )

    @property
    def experiments(self) -> dict[str, ExperimentData]:
        """
        Load every experiment recorded in the database.

        Returns
        -------
        dict[str, ExperimentData]
            Mapping from experiment id to an
            :py:class:`~mlos_bench.storage.sql.experiment_data.ExperimentSqlData`
            object, inserted in ascending ``exp_id`` order.
        """
        # FIXME: this is somewhat expensive if only fetching a single Experiment.
        # May need to expand the API or data structures to lazily fetch data and/or cache it.
        schema = self._schema  # resolve (and possibly create) the schema once
        with self._engine.connect() as conn:
            rows = conn.execute(
                schema.experiment.select().order_by(
                    schema.experiment.c.exp_id.asc(),
                )
            ).fetchall()
        # The rows are fully materialized, so the connection is released
        # before the per-experiment data objects are built.
        return {
            row.exp_id: ExperimentSqlData(
                engine=self._engine,
                schema=schema,
                experiment_id=row.exp_id,
                description=row.description,
                root_env_config=row.root_env_config,
                git_repo=row.git_repo,
                git_commit=row.git_commit,
            )
            for row in rows
        }