Coverage for mlos_bench/mlos_bench/storage/sql/experiment.py: 92%
101 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5""":py:class:`.Storage.Experiment` interface implementation for saving and restoring
6the benchmark experiment data using `SQLAlchemy <https://sqlalchemy.org>`_ backend.
7"""
9import hashlib
10import logging
11from collections.abc import Iterator
12from datetime import datetime
13from typing import Any, Literal
15from pytz import UTC
16from sqlalchemy import Connection, CursorResult, Table, column, func, select
17from sqlalchemy.engine import Engine
19from mlos_bench.environments.status import Status
20from mlos_bench.storage.base_storage import Storage
21from mlos_bench.storage.sql.common import save_params
22from mlos_bench.storage.sql.schema import DbSchema
23from mlos_bench.storage.sql.trial import Trial
24from mlos_bench.tunables.tunable_groups import TunableGroups
25from mlos_bench.util import utcify_timestamp
27_LOG = logging.getLogger(__name__)
class Experiment(Storage.Experiment):
    """Logic for retrieving and storing the results of a single experiment."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        opt_targets: dict[str, Literal["min", "max"]],
    ):
        """
        Create a SQL-backed experiment handle.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open a connection for every operation.
        schema : DbSchema
            Table definitions of the storage schema.
        tunables : TunableGroups
            Tunable parameters of the experiment.
        experiment_id : str
            Identifier of the experiment.
        trial_id : int
            Starting trial ID; `_setup()` moves it past the last stored trial
            when an existing experiment is resumed.
        root_env_config : str
            Root environment config recorded for the experiment.
        description : str
            Description recorded for a newly created experiment.
        opt_targets : dict[str, Literal["min", "max"]]
            Optimization target names mapped to their direction.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )
        self._engine = engine
        self._schema = schema

    def _setup(self) -> None:
        """
        Create the experiment record in the database, or resume an existing one.

        For a new experiment, insert its metadata row and one row per
        optimization objective. For an existing experiment, continue the trial
        numbering after the last stored trial ID, and log a warning when the
        stored git commit differs from the current one.
        """
        super()._setup()
        with self._engine.begin() as conn:
            # Get git info and the last trial ID for the experiment.
            # pylint: disable=not-callable
            exp_info = conn.execute(
                self._schema.experiment.select()
                .with_only_columns(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                    func.max(self._schema.trial.c.trial_id).label("trial_id"),
                )
                .join(
                    self._schema.trial,
                    self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id,
                    isouter=True,
                )
                .where(
                    self._schema.experiment.c.exp_id == self._experiment_id,
                )
                .group_by(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                )
            ).fetchone()
            if exp_info is None:
                _LOG.info("Start new experiment: %s", self._experiment_id)
                # It's a new experiment: create a record for it in the database.
                conn.execute(
                    self._schema.experiment.insert().values(
                        exp_id=self._experiment_id,
                        description=self._description,
                        git_repo=self._git_repo,
                        git_commit=self._git_commit,
                        root_env_config=self._root_env_config,
                    )
                )
                # An empty list passed to `.values()` does not mean "insert no
                # rows" in SQLAlchemy: it becomes a single INSERT without
                # explicit values, which violates the objectives table's
                # required columns. Skip the statement when there is nothing
                # to record.
                if self.opt_targets:
                    conn.execute(
                        self._schema.objectives.insert().values(
                            [
                                {
                                    "exp_id": self._experiment_id,
                                    "optimization_target": opt_target,
                                    "optimization_direction": opt_dir,
                                }
                                for (opt_target, opt_dir) in self.opt_targets.items()
                            ]
                        )
                    )
            else:
                # The outer join yields a NULL max(trial_id) when the
                # experiment has no trials yet; keep the initial trial ID then.
                if exp_info.trial_id is not None:
                    self._trial_id = exp_info.trial_id + 1
                _LOG.info(
                    "Continue experiment: %s last trial: %s resume from: %d",
                    self._experiment_id,
                    exp_info.trial_id,
                    self._trial_id,
                )
                # TODO: Sanity check that certain critical configs (e.g.,
                # objectives) haven't changed to be incompatible such that a new
                # experiment should be started (possibly by prewarming with the
                # previous one).
                if exp_info.git_commit != self._git_commit:
                    _LOG.warning(
                        "Experiment %s git expected: %s %s",
                        self,
                        exp_info.git_repo,
                        exp_info.git_commit,
                    )

    def merge(self, experiment_ids: list[str]) -> None:
        """
        Merge other experiments into this one (not implemented).

        Raises
        ------
        NotImplementedError
            Always.
        """
        _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
        raise NotImplementedError("TODO: Merging experiments not implemented yet.")

    def load_tunable_config(self, config_id: int) -> dict[str, Any]:
        """Return the stored tunable values of config `config_id` as a name -> value dict."""
        with self._engine.connect() as conn:
            return self._get_key_val(conn, self._schema.config_param, "param", config_id=config_id)

    def load_telemetry(self, trial_id: int) -> list[tuple[datetime, str, Any]]:
        """
        Return the telemetry of one trial of this experiment.

        Returns
        -------
        list[tuple[datetime, str, Any]]
            (timestamp, metric_id, metric_value) rows, ordered by timestamp
            and then by metric ID. Timestamps are returned as UTC.
        """
        with self._engine.connect() as conn:
            cur_telemetry = conn.execute(
                self._schema.trial_telemetry.select()
                .where(
                    self._schema.trial_telemetry.c.exp_id == self._experiment_id,
                    self._schema.trial_telemetry.c.trial_id == trial_id,
                )
                .order_by(
                    self._schema.trial_telemetry.c.ts,
                    self._schema.trial_telemetry.c.metric_id,
                )
            )
            # Not all storage backends store the original zone info.
            # We try to ensure data is entered in UTC and augment it on return again here.
            return [
                (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
                for row in cur_telemetry.fetchall()
            ]

    def load(
        self,
        last_trial_id: int = -1,
    ) -> tuple[list[int], list[dict], list[dict[str, Any] | None], list[Status]]:
        """
        Load the finished trials of this experiment.

        Only trials with ID greater than `last_trial_id` and status SUCCEEDED,
        FAILED, or TIMED_OUT are returned, in ascending trial ID order.

        Returns
        -------
        (trial_ids, configs, scores, status) : tuple of lists
            Parallel lists of trial IDs, tunable configurations, metric dicts
            (None for trials that did not succeed), and trial statuses.
        """
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.trial_id,
                    self._schema.trial.c.config_id,
                    self._schema.trial.c.status,
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id > last_trial_id,
                    self._schema.trial.c.status.in_(
                        [
                            Status.SUCCEEDED.name,
                            Status.FAILED.name,
                            Status.TIMED_OUT.name,
                        ]
                    ),
                )
                .order_by(
                    self._schema.trial.c.trial_id.asc(),
                )
            )

            trial_ids: list[int] = []
            configs: list[dict[str, Any]] = []
            scores: list[dict[str, Any] | None] = []
            status: list[Status] = []

            for trial in cur_trials.fetchall():
                stat = Status[trial.status]
                status.append(stat)
                trial_ids.append(trial.trial_id)
                configs.append(
                    self._get_key_val(
                        conn,
                        self._schema.config_param,
                        "param",
                        config_id=trial.config_id,
                    )
                )
                # Only successful trials have results worth reporting.
                if stat.is_succeeded():
                    scores.append(
                        self._get_key_val(
                            conn,
                            self._schema.trial_result,
                            "metric",
                            exp_id=self._experiment_id,
                            trial_id=trial.trial_id,
                        )
                    )
                else:
                    scores.append(None)

            return (trial_ids, configs, scores, status)

    @staticmethod
    def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> dict[str, Any]:
        """
        Helper method to retrieve key-value pairs from the database.

        (E.g., configurations, results, and telemetry.)

        Selects the `{field}_id` and `{field}_value` columns from `table`,
        filtered by equality on every keyword argument (column name -> value),
        and returns them as an `{field}_id -> {field}_value` dict.
        """
        cur_result: CursorResult[tuple[str, Any]] = conn.execute(
            select(
                column(f"{field}_id"),
                column(f"{field}_value"),
            )
            .select_from(table)
            .where(*[column(key) == val for (key, val) in kwargs.items()])
        )
        # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to
        # avoid naming conflicts.
        return dict(
            row._tuple() for row in cur_result.fetchall()  # pylint: disable=protected-access
        )

    def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
        """
        Yield the trials of this experiment that are scheduled but not finished.

        Parameters
        ----------
        timestamp : datetime
            Only trials with no start time, or a start time at or before this
            timestamp, are returned. A naive timestamp is treated as local time.
        running : bool
            If True, also include READY and RUNNING trials; otherwise only
            PENDING ones.

        Yields
        ------
        Storage.Trial
            One handle per matching trial that has no end time.
        """
        timestamp = utcify_timestamp(timestamp, origin="local")
        _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
        if running:
            pending_status = [Status.PENDING.name, Status.READY.name, Status.RUNNING.name]
        else:
            pending_status = [Status.PENDING.name]
        # Read all rows and their key/value data up front so the connection is
        # returned to the pool before the caller starts consuming (and possibly
        # running) the trials, instead of staying checked out for as long as
        # this generator is suspended.
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select().where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    (
                        self._schema.trial.c.ts_start.is_(None)
                        | (self._schema.trial.c.ts_start <= timestamp)
                    ),
                    self._schema.trial.c.ts_end.is_(None),
                    self._schema.trial.c.status.in_(pending_status),
                )
            )
            pending = [
                (
                    trial.trial_id,
                    trial.config_id,
                    self._get_key_val(
                        conn,
                        self._schema.config_param,
                        "param",
                        config_id=trial.config_id,
                    ),
                    self._get_key_val(
                        conn,
                        self._schema.trial_param,
                        "param",
                        exp_id=self._experiment_id,
                        trial_id=trial.trial_id,
                    ),
                )
                for trial in cur_trials.fetchall()
            ]
        for trial_id, config_id, tunables, config in pending:
            yield Trial(
                engine=self._engine,
                schema=self._schema,
                # Reset .is_updated flag after the assignment:
                tunables=self._tunables.copy().assign(tunables).reset(),
                experiment_id=self._experiment_id,
                trial_id=trial_id,
                config_id=config_id,
                opt_targets=self._opt_targets,
                config=config,
            )

    def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
        """
        Get the config ID for the given tunables.

        If the config does not exist, create a new record for it.

        The config is identified by the SHA-256 hash of `str(tunables)`. A new
        record stores that hash plus every tunable's name and value, all
        written through the caller's connection `conn`.
        """
        config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest()
        cur_config = conn.execute(
            self._schema.config.select().where(self._schema.config.c.config_hash == config_hash)
        ).fetchone()
        if cur_config is not None:
            return int(cur_config.config_id)  # mypy doesn't know it's always int
        # Config not found, create a new one:
        new_config_result = conn.execute(
            self._schema.config.insert().values(config_hash=config_hash)
        ).inserted_primary_key
        assert new_config_result
        config_id: int = new_config_result[0]
        save_params(
            conn,
            self._schema.config_param,
            {tunable.name: tunable.value for (tunable, _group) in tunables},
            config_id=config_id,
        )
        return config_id

    def _new_trial(
        self,
        tunables: TunableGroups,
        ts_start: datetime | None = None,
        config: dict[str, Any] | None = None,
    ) -> Storage.Trial:
        """
        Persist a new PENDING trial under the current trial ID and return it.

        Parameters
        ----------
        tunables : TunableGroups
            Tunable values of the trial; their config record is created if needed.
        ts_start : datetime | None
            Scheduled start time (naive values are treated as local time).
            Defaults to now. Microseconds are truncated.
        config : dict[str, Any] | None
            Optional framework (non-tunable) parameters to store with the trial.
        """
        # MySQL can round microseconds into the future causing scheduler to skip trials.
        # Truncate microseconds to avoid this issue.
        ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local").replace(
            microsecond=0
        )
        _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
        new_trial_status = Status.PENDING
        # `engine.begin()` commits when the block exits normally and rolls the
        # transaction back if anything inside it raises.
        with self._engine.begin() as conn:
            config_id = self._get_config_id(conn, tunables)
            conn.execute(
                self._schema.trial.insert().values(
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                    config_id=config_id,
                    ts_start=ts_start,
                    status=new_trial_status.name,
                )
            )

            # Note: config here is the framework config, not the target
            # environment config (i.e., tunables).
            if config is not None:
                save_params(
                    conn,
                    self._schema.trial_param,
                    config,
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                )

            trial = Trial(
                engine=self._engine,
                schema=self._schema,
                tunables=tunables,
                experiment_id=self._experiment_id,
                trial_id=self._trial_id,
                config_id=config_id,
                opt_targets=self._opt_targets,
                config=config,
                status=new_trial_status,
            )
        # Consume the trial ID only after the transaction has committed, so a
        # failed insert or commit does not burn an ID.
        self._trial_id += 1
        return trial