Coverage for mlos_bench/mlos_bench/storage/sql/schema.py: 100%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6DB schema definition for the :py:class:`~mlos_bench.storage.sql.storage.SqlStorage` 

7backend. 

8 

9Notes 

10----- 

11The SQL statements are generated by SQLAlchemy, but can be obtained using 

12``repr`` or ``str`` (e.g., via ``print()``) on a :py:class:`DbSchema` instance.

13The ``mlos_bench`` CLI will do this automatically if the logging level is set to 

14``DEBUG``. 

15 

16Also see the `mlos_bench CLI usage <../../../../../mlos_bench.run.usage.html>`__ for 

17details on how to invoke only the schema creation/update routines. 

18""" 

19 

20import logging 

21from importlib.resources import files 

22from typing import Any 

23 

24from alembic import command, config 

25from sqlalchemy import ( 

26 Column, 

27 Connection, 

28 DateTime, 

29 Dialect, 

30 Float, 

31 ForeignKeyConstraint, 

32 Integer, 

33 MetaData, 

34 PrimaryKeyConstraint, 

35 Sequence, 

36 String, 

37 Table, 

38 UniqueConstraint, 

39 create_mock_engine, 

40 inspect, 

41) 

42from sqlalchemy.engine import Engine 

43 

44from mlos_bench.util import path_join 

45 

# Logger for this module, named after the module itself.
_LOG: logging.Logger = logging.getLogger(name=__name__)

47 

48 

class _DDL:
    """
    Collect the SQL text of the DDL statements SQLAlchemy would execute.

    An instance is passed as the ``executor`` of a mock engine in
    ``DbSchema.__repr__``. Every statement handed to it is compiled for the
    dialect given at construction time and appended to :attr:`statements`.
    ``repr()`` (and therefore ``str()``, which falls back to it) renders the
    collected statements as one ``;``-terminated script.
    """

    def __init__(self, dialect: Dialect):
        self._dialect = dialect
        # SQL text of each captured statement, in the order received.
        self.statements: list[str] = []

    def __call__(self, sql: Any, *_args: Any, **_kwargs: Any) -> None:
        # Extra positional/keyword arguments from the executor are ignored.
        compiled = sql.compile(dialect=self._dialect)
        self.statements.append(str(compiled))

    def __repr__(self) -> str:
        script = ";\n".join(self.statements)
        # An empty joined script renders as "", otherwise terminate it with ";".
        if not script:
            return ""
        return f"{script};"

67 

class DbSchema:
    """
    A class to define and create the DB schema.

    All tables are declared in :py:meth:`__init__` on a private
    :py:class:`sqlalchemy.MetaData` object (exposed via :attr:`meta`).
    :py:meth:`create` and :py:meth:`update` apply that schema to the database
    behind the engine.
    """

    # This class is internal to SqlStorage and is mostly a struct
    # for all DB tables, so it's ok to disable the warnings.
    # pylint: disable=too-many-instance-attributes

    # Common string column sizes.
    _ID_LEN = 512
    _PARAM_VALUE_LEN = 1024
    _METRIC_VALUE_LEN = 255
    _STATUS_LEN = 16

    def __init__(self, engine: Engine | None):
        """
        Declare the SQLAlchemy schema for the database.

        Parameters
        ----------
        engine : sqlalchemy.engine.Engine | None
            The SQLAlchemy engine to use for the DB schema.
            Listed as optional for `alembic <https://alembic.sqlalchemy.org>`_
            schema migration purposes so we can reference it inside its ``env.py``
            config file for :attr:`~meta` data inspection, but won't generally be
            functional without one.
        """
        _LOG.info("Create the DB schema for: %s", engine)
        self._engine = engine
        self._meta = MetaData()

        self.experiment = Table(
            "experiment",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("description", String(1024)),
            Column("root_env_config", String(1024), nullable=False),
            Column("git_repo", String(1024), nullable=False),
            Column("git_commit", String(40), nullable=False),
            # For backwards compatibility, we allow NULL for ts_start.
            Column("ts_start", DateTime),
            Column("ts_end", DateTime),
            # Should match the text IDs of `mlos_bench.environments.Status` enum:
            # For backwards compatibility, we allow NULL for status.
            Column("status", String(self._STATUS_LEN)),
            # There may be more than one mlos_benchd_service running on different hosts.
            # This column stores the host/container name of the driver that
            # picked up the experiment.
            # They should use a transaction to update it to their own hostname when
            # they start if and only if it is NULL.
            Column("driver_name", String(40), comment="Driver Host/Container Name"),
            Column("driver_pid", Integer, comment="Driver Process ID"),
            PrimaryKeyConstraint("exp_id"),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_experiment_data.ExperimentData` info.
        """

        self.objectives = Table(
            "objectives",
            self._meta,
            # Same type as ``experiment.exp_id``. Spelled out explicitly, as in every
            # other table, rather than left for SQLAlchemy to infer from the foreign
            # key below. Being part of the primary key, it is NOT NULL either way.
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("optimization_target", String(self._ID_LEN), nullable=False),
            Column("optimization_direction", String(4), nullable=False),
            # TODO: Note: weight is not fully supported yet as currently
            # multi-objective is expected to explore each objective equally.
            # Will need to adjust the insert and return values to support this
            # eventually.
            Column("weight", Float, nullable=True),
            PrimaryKeyConstraint("exp_id", "optimization_target"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_storage.Storage.Experiment` optimization
        objectives info.
        """

        # A workaround for SQLAlchemy issue with autoincrement in DuckDB:
        if engine and engine.dialect.name == "duckdb":
            seq_config_id = Sequence("seq_config_id")
            col_config_id = Column(
                "config_id",
                Integer,
                seq_config_id,
                server_default=seq_config_id.next_value(),
                nullable=False,
                primary_key=True,
            )
        else:
            col_config_id = Column(
                "config_id",
                Integer,
                nullable=False,
                primary_key=True,
                autoincrement=True,
            )

        self.config = Table(
            "config",
            self._meta,
            col_config_id,
            Column("config_hash", String(64), nullable=False, unique=True),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_tunable_config_data.TunableConfigData`
        info.
        """

        self.trial = Table(
            "trial",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("config_id", Integer, nullable=False),
            Column("trial_runner_id", Integer, nullable=True, default=None),
            Column("ts_start", DateTime, nullable=False),
            Column("ts_end", DateTime),
            # Should match the text IDs of `mlos_bench.environments.Status` enum:
            Column("status", String(self._STATUS_LEN), nullable=False),
            PrimaryKeyConstraint("exp_id", "trial_id"),
            ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        info.
        """

        # Values of the tunable parameters of the experiment,
        # fixed for a particular trial config.
        self.config_param = Table(
            "config_param",
            self._meta,
            Column("config_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),
            PrimaryKeyConstraint("config_id", "param_id"),
            ForeignKeyConstraint(["config_id"], [self.config.c.config_id]),
        )
        """The Table storing
        :py:class:`~mlos_bench.storage.base_tunable_config_data.TunableConfigData`
        info.
        """

        # Values of additional non-tunable parameters of the trial,
        # e.g., scheduled execution time, VM name / location, number of repeats, etc.
        self.trial_param = Table(
            "trial_param",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("param_id", String(self._ID_LEN), nullable=False),
            Column("param_value", String(self._PARAM_VALUE_LEN)),
            PrimaryKeyConstraint("exp_id", "trial_id", "param_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:attr:`metadata <mlos_bench.storage.base_trial_data.TrialData.metadata_dict>`
        info.
        """

        self.trial_status = Table(
            "trial_status",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # NOTE(review): ``default="now"`` is a client-side literal *string*, not a
            # call to a clock function; whether a DateTime column accepts it is
            # dialect/driver dependent. Callers presumably always pass ``ts``
            # explicitly -- verify before relying on this default.
            Column("ts", DateTime(timezone=True), nullable=False, default="now"),
            Column("status", String(self._STATUS_LEN), nullable=False),
            UniqueConstraint("exp_id", "trial_id", "ts"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:class:`~mlos_bench.environments.status.Status` info.
        """

        self.trial_result = Table(
            "trial_result",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),
            PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:attr:`results <mlos_bench.storage.base_trial_data.TrialData.results_dict>`
        info.
        """

        self.trial_telemetry = Table(
            "trial_telemetry",
            self._meta,
            Column("exp_id", String(self._ID_LEN), nullable=False),
            Column("trial_id", Integer, nullable=False),
            # NOTE(review): same literal-string ``default="now"`` caveat as in
            # ``trial_status.ts`` -- verify callers always supply ``ts``.
            Column("ts", DateTime(timezone=True), nullable=False, default="now"),
            Column("metric_id", String(self._ID_LEN), nullable=False),
            Column("metric_value", String(self._METRIC_VALUE_LEN)),
            UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"),
            ForeignKeyConstraint(
                ["exp_id", "trial_id"],
                [self.trial.c.exp_id, self.trial.c.trial_id],
            ),
        )
        """The Table storing :py:class:`~mlos_bench.storage.base_trial_data.TrialData`
        :py:attr:`telemetry <mlos_bench.storage.base_trial_data.TrialData.telemetry_df>`
        info.
        """

        _LOG.debug("Schema: %s", self._meta)

    @property
    def meta(self) -> MetaData:
        """Return the SQLAlchemy MetaData object."""
        return self._meta

    @staticmethod
    def _get_alembic_cfg(conn: Connection) -> config.Config:
        """
        Build an alembic config bound to an existing connection.

        The config is loaded from the ``alembic.ini`` file located in the
        ``mlos_bench.storage.sql`` package directory. ``conn`` is stored under the
        ``"connection"`` key of its ``attributes`` (presumably picked up by the
        package's alembic ``env.py`` -- verify there).

        Parameters
        ----------
        conn : sqlalchemy.Connection
            The open connection the alembic commands should run on.

        Returns
        -------
        alembic.config.Config
            The alembic configuration object.
        """
        alembic_cfg = config.Config(
            path_join(str(files("mlos_bench.storage.sql")), "alembic.ini", abs_path=True)
        )
        alembic_cfg.attributes["connection"] = conn
        return alembic_cfg

    def create(self) -> "DbSchema":
        """
        Create the DB schema.

        Creates any missing tables via ``MetaData.create_all()``. If the resulting
        ``trial`` table already has the ``trial_runner_id`` column and there is no
        ``alembic_version`` table, the DB is stamped with the alembic ``heads``
        revision.

        Returns
        -------
        DbSchema
            ``self``, to allow call chaining.
        """
        _LOG.info("Create the DB schema")
        assert self._engine
        self._meta.create_all(self._engine)
        with self._engine.begin() as conn:
            # If the trial table has the trial_runner_id column but no
            # "alembic_version" table, then the schema is up to date as of initial
            # create and we should mark it as such to avoid trying to run the
            # (non-idempotent) upgrade scripts.
            # Otherwise, either we already have an alembic_version table and can
            # safely run the necessary upgrades or we are missing the
            # trial_runner_id column (the first to introduce schema updates) and
            # should run the upgrades.
            inspector = inspect(conn)
            has_trial_runner_id = any(
                column["name"] == "trial_runner_id"
                for column in inspector.get_columns(self.trial.name)
            )
            if has_trial_runner_id and not inspector.has_table("alembic_version"):
                # Mark the schema as up to date.
                alembic_cfg = self._get_alembic_cfg(conn)
                command.stamp(alembic_cfg, "heads")
        return self

    def update(self) -> "DbSchema":
        """
        Updates the DB schema to the latest version.

        Runs the alembic ``upgrade`` command to the ``head`` revision on a
        connection from this schema's engine.

        Returns
        -------
        DbSchema
            ``self``, to allow call chaining.

        Notes
        -----
        Also see the `mlos_bench CLI usage <../../../../../mlos_bench.run.usage.html>`__
        for details on how to invoke only the schema creation/update routines.
        """
        assert self._engine
        with self._engine.connect() as conn:
            alembic_cfg = self._get_alembic_cfg(conn)
            command.upgrade(alembic_cfg, "head")
        return self

    def __repr__(self) -> str:
        """
        Produce a string with all SQL statements required to create the schema from
        scratch in current SQL dialect.

        That is, return a collection of CREATE TABLE statements and such.
        NOTE: this method is quite heavy! We use it only once at startup
        to log the schema, and if the logging level is set to DEBUG.
        ``str()`` falls back to this method as no ``__str__`` is defined.

        Returns
        -------
        sql : str
            A multi-line string with SQL statements to create the DB schema from scratch.
        """
        assert self._engine
        ddl = _DDL(self._engine.dialect)
        # The mock engine does not touch the DB; it hands every DDL statement to `ddl`.
        mock_engine = create_mock_engine(self._engine.url, executor=ddl)
        self._meta.create_all(mock_engine, checkfirst=False)
        return str(ddl)
356 return str(ddl)