Coverage for mlos_bench/mlos_bench/storage/sql/schema.py: 95%
37 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""DB schema definition."""
import logging
from typing import Any, List

from sqlalchemy import (
    Column,
    DateTime,
    Dialect,
    Engine,
    Float,
    ForeignKeyConstraint,
    Integer,
    MetaData,
    PrimaryKeyConstraint,
    Sequence,
    String,
    Table,
    UniqueConstraint,
    create_mock_engine,
    func,
)
# Module-level logger, named after this module.
_LOG = logging.getLogger(__name__)
class _DDL:
    """
    Collect the DDL statements SQLAlchemy emits, compiled to SQL text.

    An instance is meant to be passed as the ``executor`` of a mock engine
    (see `DbSchema.__repr__()` below): every statement object it receives is
    compiled for the configured dialect and its text is appended to
    :attr:`statements`.
    """

    def __init__(self, dialect: Dialect):
        self._dialect = dialect
        # Compiled SQL text, one entry per received statement, in call order.
        self.statements: List[str] = []

    def __call__(self, sql: Any, *_args: Any, **_kwargs: Any) -> None:
        """Compile ``sql`` for the configured dialect and record its text."""
        compiled = sql.compile(dialect=self._dialect)
        self.statements.append(str(compiled))

    def __repr__(self) -> str:
        """
        Return the recorded statements joined by ``";\\n"`` with a trailing ``";"``.

        Returns an empty string when the joined text is empty.
        """
        joined = ";\n".join(self.statements)
        if not joined:
            return ""
        return f"{joined};"
class DbSchema:
    """
    A class to define and create the DB schema.

    Every table is exposed as a public :class:`~sqlalchemy.Table` attribute
    (``experiment``, ``objectives``, ``config``, ``trial``, ``config_param``,
    ``trial_param``, ``trial_status``, ``trial_result``, ``trial_telemetry``),
    all registered on one private :class:`~sqlalchemy.MetaData` instance.
    """

    # This class is internal to SqlStorage and is mostly a struct
    # for all DB tables, so it's ok to disable the warnings.
    # pylint: disable=too-many-instance-attributes

    # Common string column sizes.
    _ID_LEN = 512
    _PARAM_VALUE_LEN = 1024
    _METRIC_VALUE_LEN = 255
    _STATUS_LEN = 16

    def __init__(self, engine: Engine):
        """
        Declare the SQLAlchemy schema for the database.

        This only builds the in-memory table definitions; nothing is sent to
        the database until :meth:`create` is called.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine the schema is declared for. Its dialect name
            selects how the auto-incrementing ``config.config_id`` column is
            defined (DuckDB gets an explicit sequence).
        """
        _LOG.info("Create the DB schema for: %s", engine)
        self._engine = engine
        # TODO: bind for automatic schema updates? (#649)
        self._meta = MetaData()

        self.experiment = Table(
            "experiment",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("description", String(1024)),
            Column("root_env_config", String(1024), nullable=False),
            Column("git_repo", String(1024), nullable=False),
            # NOTE(review): 40 chars presumably fits a full hex SHA-1 commit id.
            Column("git_commit", String(40), nullable=False),
            PrimaryKeyConstraint("exp_id"),
        )

        self.objectives = Table(
            "objectives",
            self._meta,
            # Typed explicitly, like ``exp_id`` in every other table, instead of
            # relying on the type being inferred from the foreign key below.
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("optimization_target", String(self._ID_LEN), nullable=False),
            # NOTE(review): 4 chars presumably holds "min" / "max" -- confirm.
            Column("optimization_direction", String(4), nullable=False),
            # TODO: Note: weight is not fully supported yet as currently
            # multi-objective is expected to explore each objective equally.
            # Will need to adjust the insert and return values to support this
            # eventually.
            Column("weight", Float, nullable=True),
            PrimaryKeyConstraint("exp_id", "optimization_target"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
        )

        # A workaround for SQLAlchemy issue with autoincrement in DuckDB:
        # draw the id from an explicit sequence via a server-side default.
        if engine.dialect.name == "duckdb":
            seq_config_id = Sequence("seq_config_id")
            col_config_id = Column(
                "config_id",
                Integer,
                seq_config_id,
                server_default=seq_config_id.next_value(),
                nullable=False,
                primary_key=True,
            )
        else:
            col_config_id = Column(
                "config_id",
                Integer,
                nullable=False,
                primary_key=True,
                autoincrement=True,
            )

        self.config = Table(
            "config",
            self._meta,
            col_config_id,
            # NOTE(review): 64 chars presumably fits a hex SHA-256 digest.
            Column("config_hash", String(64), nullable=False, unique=True),
        )

        self.trial = Table(
            "trial",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("config_id", Integer, nullable=False),
            Column("ts_start", DateTime, nullable=False),
            Column("ts_end", DateTime),
            # Should match the text IDs of `mlos_bench.environments.Status` enum:
            Column("status", String(self._STATUS_LEN), nullable=False),
            PrimaryKeyConstraint("exp_id", "trial_id"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )

        # Values of the tunable parameters of the experiment,
        # fixed for a particular trial config.
        self.config_param = Table(
            "config_param",
            self._meta,
            Column("config_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),
            PrimaryKeyConstraint("config_id", "param_id"),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )

        # Values of additional non-tunable parameters of the trial,
        # e.g., scheduled execution time, VM name / location, number of repeats, etc.
        self.trial_param = Table(
            "trial_param",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),
            PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        self.trial_status = Table(
            "trial_status",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # Client-side INSERT default (used only when ``ts`` is omitted).
            # A SQL expression rather than the literal string "now": SQLAlchemy
            # renders ``func.now()`` per dialect (e.g. CURRENT_TIMESTAMP on
            # SQLite), whereas a plain string is not accepted by every
            # backend's DateTime type (SQLite's accepts only datetime objects).
            Column("ts", DateTime(timezone=True), nullable=False, default=func.now()),
            Column("status", String(self._STATUS_LEN), nullable=False),
            UniqueConstraint("exp_id", "trial_id", "ts"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        self.trial_result = Table(
            "trial_result",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),
            PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        self.trial_telemetry = Table(
            "trial_telemetry",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # Dialect-portable "current timestamp" INSERT default (see ``func.now()``).
            Column("ts", DateTime(timezone=True), nullable=False, default=func.now()),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),
            UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )

        _LOG.debug("Schema: %s", self._meta)

    def create(self) -> "DbSchema":
        """
        Create the DB schema.

        Issues ``MetaData.create_all`` against the engine given to the
        constructor (with SQLAlchemy's default ``checkfirst`` behavior).

        Returns
        -------
        self : DbSchema
            This object, so the call can be chained after construction.
        """
        _LOG.info("Create the DB schema")
        self._meta.create_all(self._engine)
        return self

    def __repr__(self) -> str:
        """
        Produce a string with all SQL statements required to create the schema from
        scratch in current SQL dialect.

        That is, return a collection of CREATE TABLE statements and such.
        NOTE: this method is quite heavy! We use it only once at startup
        to log the schema, and if the logging level is set to DEBUG.

        Returns
        -------
        sql : str
            A multi-line string with SQL statements to create the DB schema from scratch.
        """
        ddl = _DDL(self._engine.dialect)
        # A mock engine sends every statement to ``ddl`` instead of a database.
        mock_engine = create_mock_engine(self._engine.url, executor=ddl)
        # checkfirst=False: emit every CREATE without probing for existing tables.
        self._meta.create_all(mock_engine, checkfirst=False)
        return str(ddl)