Coverage for mlos_bench/mlos_bench/storage/sql/schema.py: 100%
69 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
DB schema definition for the :py:class:`~mlos_bench.storage.sql.storage.SqlStorage`
backend.

Notes
-----
The SQL statements are generated by SQLAlchemy, but can be obtained using
``repr`` or ``str`` (e.g., via ``print()``) on a ``DbSchema`` object.
The ``mlos_bench`` CLI will do this automatically if the logging level is set to
``DEBUG``.

Also see the `mlos_bench CLI usage <../../../../../mlos_bench.run.usage.html>`__ for
details on how to invoke only the schema creation/update routines.
18"""
20import logging
21from importlib.resources import files
22from typing import Any
24from alembic import command, config
25from sqlalchemy import (
26 Column,
27 Connection,
28 DateTime,
29 Dialect,
30 Float,
31 ForeignKeyConstraint,
32 Integer,
33 MetaData,
34 PrimaryKeyConstraint,
35 Sequence,
36 String,
37 Table,
38 UniqueConstraint,
39 create_mock_engine,
40 inspect,
41)
42from sqlalchemy.engine import Engine
44from mlos_bench.util import path_join
# Module-level logger, named after this module so that its verbosity can be
# configured independently through the standard ``logging`` hierarchy.
_LOG = logging.getLogger(__name__)
class _DDL:
    """
    A helper class to capture the DDL statements from SQLAlchemy.

    It is used in the `DbSchema.__repr__()` method below, where an instance is
    passed as the ``executor`` callback of a mock engine: every statement handed
    to the callback is compiled for the given dialect and recorded as text
    instead of being executed.
    """

    def __init__(self, dialect: Dialect):
        """
        Create an empty statement collector.

        Parameters
        ----------
        dialect : sqlalchemy.Dialect
            The SQL dialect used to compile the captured statements.
        """
        self._dialect = dialect
        self.statements: list[str] = []

    def __call__(self, sql: Any, *_args: Any, **_kwargs: Any) -> None:
        """
        Compile a single statement for the configured dialect and record it.

        Parameters
        ----------
        sql : Any
            An object exposing ``compile(dialect=...)``.
            Any additional positional or keyword arguments are ignored.
        """
        self.statements.append(str(sql.compile(dialect=self._dialect)))

    def __repr__(self) -> str:
        """
        Return all captured statements as a single SQL script.

        Returns
        -------
        sql : str
            The statements joined by ``";\\n"`` and terminated with ``";"``,
            or an empty string if the joined text is empty.
        """
        res = ";\n".join(self.statements)
        return res + ";" if res else ""
class DbSchema:
    """
    A class to define and create the DB schema.

    Each table of the schema is exposed as an instance attribute holding the
    corresponding SQLAlchemy ``Table`` object; all of them are registered in a
    single ``MetaData`` object available via the :attr:`~meta` property.
    """

    # This class is internal to SqlStorage and is mostly a struct
    # for all DB tables, so it's ok to disable the warnings.
    # pylint: disable=too-many-instance-attributes

    # Common string column sizes.
    _ID_LEN = 512
    _PARAM_VALUE_LEN = 1024
    _METRIC_VALUE_LEN = 255
    _STATUS_LEN = 16

    def __init__(self, engine: Engine | None):
        """
        Declare the SQLAlchemy schema for the database.

        Parameters
        ----------
        engine : sqlalchemy.engine.Engine | None
            The SQLAlchemy engine to use for the DB schema.
            Listed as optional for `alembic <https://alembic.sqlalchemy.org>`_
            schema migration purposes so we can reference it inside its ``env.py``
            config file for :attr:`~meta` data inspection, but won't generally be
            functional without one.
        """
        _LOG.info("Create the DB schema for: %s", engine)
        self._engine = engine
        self._meta = MetaData()

        self.experiment = Table(
            "experiment",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("description", String(1024)),
            Column("root_env_config", String(1024), nullable=False),
            Column("git_repo", String(1024), nullable=False),
            Column("git_commit", String(40), nullable=False),
            # For backwards compatibility, we allow NULL for ts_start.
            Column("ts_start", DateTime),
            Column("ts_end", DateTime),
            # Should match the text IDs of `mlos_bench.environments.Status` enum:
            # For backwards compatibility, we allow NULL for status.
            Column("status", String(self._STATUS_LEN)),
            # There may be more than one mlos_benchd_service running on different hosts.
            # This column stores the host/container name of the driver that
            # picked up the experiment.
            # They should use a transaction to update it to their own hostname when
            # they start if and only if it is NULL.
            Column("driver_name", String(40), comment="Driver Host/Container Name"),
            Column("driver_pid", Integer, comment="Driver Process ID"),
            PrimaryKeyConstraint("exp_id"),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_experiment_data.ExperimentData` info.
        """

        self.objectives = Table(
            "objectives",
            self._meta,
            # The column type is omitted on purpose: SQLAlchemy copies it from the
            # referenced `experiment.exp_id` column once the foreign key below
            # is resolved.
            Column("exp_id"),
            Column("optimization_target", String(self._ID_LEN), nullable=False),
            Column("optimization_direction", String(4), nullable=False),
            # TODO: Note: weight is not fully supported yet as currently
            # multi-objective is expected to explore each objective equally.
            # Will need to adjust the insert and return values to support this
            # eventually.
            Column("weight", Float, nullable=True),
            PrimaryKeyConstraint("exp_id", "optimization_target"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_storage.Storage.Experiment` optimization
        objectives info.
        """

        # A workaround for SQLAlchemy issue with autoincrement in DuckDB:
        if engine and engine.dialect.name == "duckdb":
            seq_config_id = Sequence("seq_config_id")
            col_config_id = Column(
                "config_id",
                Integer,
                seq_config_id,
                server_default=seq_config_id.next_value(),
                nullable=False,
                primary_key=True,
            )
        else:
            col_config_id = Column(
                "config_id",
                Integer,
                nullable=False,
                primary_key=True,
                autoincrement=True,
            )

        self.config = Table(
            "config",
            self._meta,
            col_config_id,
            Column("config_hash", String(64), nullable=False, unique=True),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_tunable_config_data.TunableConfigData`
        info.
        """

        self.trial = Table(
            "trial",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("config_id", Integer, nullable=False),
            Column("trial_runner_id", Integer, nullable=True, default=None),
            Column("ts_start", DateTime, nullable=False),
            Column("ts_end", DateTime),
            # Should match the text IDs of `mlos_bench.environments.Status` enum:
            Column("status", String(self._STATUS_LEN), nullable=False),
            PrimaryKeyConstraint("exp_id", "trial_id"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        info.
        """

        # Values of the tunable parameters of the experiment,
        # fixed for a particular trial config.
        self.config_param = Table(
            "config_param",
            self._meta,
            Column("config_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),
            PrimaryKeyConstraint("config_id", "param_id"),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_tunable_config_data.TunableConfigData`
        info.
        """

        # Values of additional non-tunable parameters of the trial,
        # e.g., scheduled execution time, VM name / location, number of repeats, etc.
        self.trial_param = Table(
            "trial_param",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),
            PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:attr:`metadata <mlos_bench.storage.base_trial_data.TrialData.metadata_dict>`
        info.
        """

        self.trial_status = Table(
            "trial_status",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # NOTE(review): `default="now"` is a client-side literal string, not a
            # callable or an SQL function. Presumably callers always supply `ts`
            # explicitly - confirm before relying on this default.
            Column("ts", DateTime(timezone=True), nullable=False, default="now"),
            Column("status", String(self._STATUS_LEN), nullable=False),
            UniqueConstraint("exp_id", "trial_id", "ts"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:class:`~mlos_bench.environments.status.Status` info.
        """

        self.trial_result = Table(
            "trial_result",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),
            PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:attr:`results <mlos_bench.storage.base_trial_data.TrialData.results_dict>`
        info.
        """

        self.trial_telemetry = Table(
            "trial_telemetry",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # NOTE(review): same literal `default="now"` as in `trial_status.ts`;
            # presumably `ts` is always provided by the caller - confirm.
            Column("ts", DateTime(timezone=True), nullable=False, default="now"),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),
            UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:attr:`telemetry <mlos_bench.storage.base_trial_data.TrialData.telemetry_df>`
        info.
        """

        _LOG.debug("Schema: %s", self._meta)

    @property
    def meta(self) -> MetaData:
        """Return the SQLAlchemy MetaData object."""
        return self._meta

    @staticmethod
    def _get_alembic_cfg(conn: Connection) -> config.Config:
        """
        Build the alembic config bound to an existing DB connection.

        Parameters
        ----------
        conn : sqlalchemy.Connection
            The connection to store in the ``connection`` attribute of the config.

        Returns
        -------
        alembic_cfg : alembic.config.Config
            The config loaded from the ``alembic.ini`` file located in the
            ``mlos_bench.storage.sql`` package.
        """
        alembic_cfg = config.Config(
            path_join(str(files("mlos_bench.storage.sql")), "alembic.ini", abs_path=True)
        )
        alembic_cfg.attributes["connection"] = conn
        return alembic_cfg

    def create(self) -> "DbSchema":
        """
        Create the DB schema.

        Returns
        -------
        schema : DbSchema
            This object, to allow call chaining.
        """
        _LOG.info("Create the DB schema")
        assert self._engine, "Cannot create the DB schema without an engine."
        self._meta.create_all(self._engine)
        with self._engine.begin() as conn:
            # If the trial table has the trial_runner_id column but no
            # "alembic_version" table, then the schema is up to date as of initial
            # create and we should mark it as such to avoid trying to run the
            # (non-idempotent) upgrade scripts.
            # Otherwise, either we already have an alembic_version table and can
            # safely run the necessary upgrades or we are missing the
            # trial_runner_id column (the first to introduce schema updates) and
            # should run the upgrades.
            inspector = inspect(conn)  # Reuse a single inspector for both checks.
            has_trial_runner_id = any(
                column["name"] == "trial_runner_id"
                for column in inspector.get_columns(self.trial.name)
            )
            if has_trial_runner_id and not inspector.has_table("alembic_version"):
                # Mark the schema as up to date.
                alembic_cfg = self._get_alembic_cfg(conn)
                command.stamp(alembic_cfg, "heads")
            # command.current(alembic_cfg)
        return self

    def update(self) -> "DbSchema":
        """
        Updates the DB schema to the latest version.

        Returns
        -------
        schema : DbSchema
            This object, to allow call chaining.

        Notes
        -----
        Also see the `mlos_bench CLI usage <../../../../../mlos_bench.run.usage.html>`__
        for details on how to invoke only the schema creation/update routines.
        """
        assert self._engine, "Cannot update the DB schema without an engine."
        with self._engine.connect() as conn:
            alembic_cfg = self._get_alembic_cfg(conn)
            command.upgrade(alembic_cfg, "head")
        return self

    def __repr__(self) -> str:
        """
        Produce a string with all SQL statements required to create the schema from
        scratch in current SQL dialect.

        That is, return a collection of CREATE TABLE statements and such.
        NOTE: this method is quite heavy! We use it only once at startup
        to log the schema, and if the logging level is set to DEBUG.

        Returns
        -------
        sql : str
            A multi-line string with SQL statements to create the DB schema from scratch.
            If the schema was declared without an engine, there is no dialect to
            render the statements for, so a short placeholder is returned instead.
        """
        if self._engine is None:
            # `repr()` must not raise: the engine is optional (see `__init__`),
            # e.g., when the schema is only used for metadata inspection.
            return f"{type(self).__name__}(engine=None)"
        ddl = _DDL(self._engine.dialect)
        mock_engine = create_mock_engine(self._engine.url, executor=ddl)
        self._meta.create_all(mock_engine, checkfirst=False)
        return str(ddl)