Coverage for mlos_bench/mlos_bench/storage/sql/experiment.py: 89%
100 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Saving and restoring the benchmark data using SQLAlchemy."""
7import hashlib
8import logging
9from datetime import datetime
10from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple
12from pytz import UTC
13from sqlalchemy import Connection, CursorResult, Engine, Table, column, func, select
15from mlos_bench.environments.status import Status
16from mlos_bench.storage.base_storage import Storage
17from mlos_bench.storage.sql.schema import DbSchema
18from mlos_bench.storage.sql.trial import Trial
19from mlos_bench.tunables.tunable_groups import TunableGroups
20from mlos_bench.util import nullable, utcify_timestamp
# Module-level logger, named after this module.
_LOG: logging.Logger = logging.getLogger(name=__name__)
class Experiment(Storage.Experiment):
    """Logic for retrieving and storing the results of a single experiment."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        opt_targets: Dict[str, Literal["min", "max"]],
    ):
        """
        Create an SQL-backed experiment handle.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections and transactions.
        schema : DbSchema
            Table definitions used by all queries in this class.
        tunables, experiment_id, trial_id, root_env_config, description, opt_targets
            Forwarded unchanged to ``Storage.Experiment.__init__``.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )
        self._engine = engine
        self._schema = schema

    def _setup(self) -> None:
        """
        Create the experiment record, or resume an existing one.

        If no experiment row exists for this ID, insert one together with its
        optimization objectives. Otherwise continue trial numbering from the
        highest stored trial ID + 1, and warn when the stored git commit
        differs from the current one.
        """
        super()._setup()
        with self._engine.begin() as conn:
            # Get git info and the last trial ID for the experiment.
            # The outer join keeps the experiment row even when it has no
            # trials yet; in that case the aggregated trial_id is NULL.
            # pylint: disable=not-callable
            exp_info = conn.execute(
                self._schema.experiment.select()
                .with_only_columns(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                    func.max(self._schema.trial.c.trial_id).label("trial_id"),
                )
                .join(
                    self._schema.trial,
                    self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id,
                    isouter=True,
                )
                .where(
                    self._schema.experiment.c.exp_id == self._experiment_id,
                )
                .group_by(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                )
            ).fetchone()
            if exp_info is None:
                _LOG.info("Start new experiment: %s", self._experiment_id)
                # It's a new experiment: create a record for it in the database.
                conn.execute(
                    self._schema.experiment.insert().values(
                        exp_id=self._experiment_id,
                        description=self._description,
                        git_repo=self._git_repo,
                        git_commit=self._git_commit,
                        root_env_config=self._root_env_config,
                    )
                )
                # One objectives row per (target, direction) pair.
                conn.execute(
                    self._schema.objectives.insert().values(
                        [
                            {
                                "exp_id": self._experiment_id,
                                "optimization_target": opt_target,
                                "optimization_direction": opt_dir,
                            }
                            for (opt_target, opt_dir) in self.opt_targets.items()
                        ]
                    )
                )
            else:
                if exp_info.trial_id is not None:
                    # Resume numbering right after the last stored trial.
                    self._trial_id = exp_info.trial_id + 1
                _LOG.info(
                    "Continue experiment: %s last trial: %s resume from: %d",
                    self._experiment_id,
                    exp_info.trial_id,
                    self._trial_id,
                )
                # TODO: Sanity check that certain critical configs (e.g.,
                # objectives) haven't changed to be incompatible such that a new
                # experiment should be started (possibly by prewarming with the
                # previous one).
                if exp_info.git_commit != self._git_commit:
                    _LOG.warning(
                        "Experiment %s git expected: %s %s",
                        self,
                        exp_info.git_repo,
                        exp_info.git_commit,
                    )

    def merge(self, experiment_ids: List[str]) -> None:
        """
        Merge other experiments into this one.

        Not implemented yet: always raises ``NotImplementedError``.
        """
        _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
        raise NotImplementedError("TODO")

    def load_tunable_config(self, config_id: int) -> Dict[str, Any]:
        """Return the ``{param_id: param_value}`` mapping stored for ``config_id``."""
        with self._engine.connect() as conn:
            return self._get_key_val(conn, self._schema.config_param, "param", config_id=config_id)

    def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
        """
        Return the telemetry records of one trial of this experiment.

        Returns
        -------
        List[Tuple[datetime, str, Any]]
            ``(timestamp, metric_id, metric_value)`` tuples, ordered by
            timestamp and then by metric ID.
        """
        with self._engine.connect() as conn:
            cur_telemetry = conn.execute(
                self._schema.trial_telemetry.select()
                .where(
                    self._schema.trial_telemetry.c.exp_id == self._experiment_id,
                    self._schema.trial_telemetry.c.trial_id == trial_id,
                )
                .order_by(
                    self._schema.trial_telemetry.c.ts,
                    self._schema.trial_telemetry.c.metric_id,
                )
            )
            # Not all storage backends store the original zone info.
            # We try to ensure data is entered in UTC and augment it on return again here.
            return [
                (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
                for row in cur_telemetry.fetchall()
            ]

    def load(
        self,
        last_trial_id: int = -1,
    ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
        """
        Load the completed trials of this experiment.

        Only trials with status SUCCEEDED, FAILED, or TIMED_OUT and an ID
        strictly greater than ``last_trial_id`` are returned, in ascending
        trial ID order.

        Parameters
        ----------
        last_trial_id : int
            Skip trials with IDs up to and including this one.
            The default of -1 loads all of them.

        Returns
        -------
        (trial_ids, configs, scores, status)
            Parallel lists, one entry per trial. ``scores`` is ``None`` for
            trials that did not succeed.
        """
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.trial_id,
                    self._schema.trial.c.config_id,
                    self._schema.trial.c.status,
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id > last_trial_id,
                    self._schema.trial.c.status.in_(["SUCCEEDED", "FAILED", "TIMED_OUT"]),
                )
                .order_by(
                    self._schema.trial.c.trial_id.asc(),
                )
            )

            trial_ids: List[int] = []
            configs: List[Dict[str, Any]] = []
            scores: List[Optional[Dict[str, Any]]] = []
            status: List[Status] = []

            for trial in cur_trials.fetchall():
                stat = Status[trial.status]
                status.append(stat)
                trial_ids.append(trial.trial_id)
                configs.append(
                    self._get_key_val(
                        conn,
                        self._schema.config_param,
                        "param",
                        config_id=trial.config_id,
                    )
                )
                # Results are only looked up for successful trials.
                if stat.is_succeeded():
                    scores.append(
                        self._get_key_val(
                            conn,
                            self._schema.trial_result,
                            "metric",
                            exp_id=self._experiment_id,
                            trial_id=trial.trial_id,
                        )
                    )
                else:
                    scores.append(None)

            return (trial_ids, configs, scores, status)

    @staticmethod
    def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> Dict[str, Any]:
        """
        Helper method to retrieve key-value pairs from the database.

        (E.g., configurations, results, and telemetry).

        Selects the ``{field}_id`` and ``{field}_value`` columns from ``table``.
        Rows are filtered by equality on every ``kwargs`` column, and the
        result is returned as a ``{id: value}`` dict.
        """
        cur_result: CursorResult[Tuple[str, Any]] = conn.execute(
            select(
                column(f"{field}_id"),
                column(f"{field}_value"),
            )
            .select_from(table)
            .where(*[column(key) == val for (key, val) in kwargs.items()])
        )
        # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to
        # avoid naming conflicts.
        return dict(
            row._tuple() for row in cur_result.fetchall()  # pylint: disable=protected-access
        )

    @staticmethod
    def _save_params(
        conn: Connection,
        table: Table,
        params: Dict[str, Any],
        **kwargs: Any,
    ) -> None:
        """
        Bulk-insert ``params`` into ``table`` as ``param_id``/``param_value`` rows.

        Every row also carries the ``kwargs`` columns (e.g. foreign keys).
        Values are converted with ``nullable(str, ...)`` before storing.
        Does nothing when ``params`` is empty.
        """
        if not params:
            return
        conn.execute(
            table.insert(),
            [
                {**kwargs, "param_id": key, "param_value": nullable(str, val)}
                for (key, val) in params.items()
            ],
        )

    def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
        """
        Yield the trials of this experiment that are waiting to be run.

        A trial qualifies when it has no end timestamp and its start timestamp
        is unset or not later than ``timestamp``. Its status must be PENDING;
        READY and RUNNING also qualify when ``running`` is true.

        Parameters
        ----------
        timestamp : datetime
            Cut-off time, normalized via ``utcify_timestamp(..., origin="local")``.
        running : bool
            If true, also return trials that are READY or RUNNING.
        """
        timestamp = utcify_timestamp(timestamp, origin="local")
        _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
        if running:
            pending_status = ["PENDING", "READY", "RUNNING"]
        else:
            pending_status = ["PENDING"]
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select().where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    (
                        self._schema.trial.c.ts_start.is_(None)
                        | (self._schema.trial.c.ts_start <= timestamp)
                    ),
                    self._schema.trial.c.ts_end.is_(None),
                    self._schema.trial.c.status.in_(pending_status),
                )
            )
            for trial in cur_trials.fetchall():
                # Tunable values shared by all trials with this config ID.
                tunable_values = self._get_key_val(
                    conn,
                    self._schema.config_param,
                    "param",
                    config_id=trial.config_id,
                )
                # Per-trial framework (non-tunable) parameters.
                trial_config = self._get_key_val(
                    conn,
                    self._schema.trial_param,
                    "param",
                    exp_id=self._experiment_id,
                    trial_id=trial.trial_id,
                )
                yield Trial(
                    engine=self._engine,
                    schema=self._schema,
                    # Reset .is_updated flag after the assignment:
                    tunables=self._tunables.copy().assign(tunable_values).reset(),
                    experiment_id=self._experiment_id,
                    trial_id=trial.trial_id,
                    config_id=trial.config_id,
                    opt_targets=self._opt_targets,
                    config=trial_config,
                )

    def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
        """
        Get the config ID for the given tunables.

        If the config does not exist, create a new record for it.

        Configs are identified by the SHA-256 hex digest of ``str(tunables)``.
        A new config also gets its tunable values stored in ``config_param``.
        """
        # NOTE(review): lookup-then-insert is not atomic; concurrent writers
        # could race on the same hash unless the schema enforces uniqueness.
        config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest()
        cur_config = conn.execute(
            self._schema.config.select().where(self._schema.config.c.config_hash == config_hash)
        ).fetchone()
        if cur_config is not None:
            return int(cur_config.config_id)  # mypy doesn't know it's always int
        # Config not found, create a new one:
        config_id: int = conn.execute(
            self._schema.config.insert().values(config_hash=config_hash)
        ).inserted_primary_key[0]
        self._save_params(
            conn,
            self._schema.config_param,
            {tunable.name: tunable.value for (tunable, _group) in tunables},
            config_id=config_id,
        )
        return config_id

    def _new_trial(
        self,
        tunables: TunableGroups,
        ts_start: Optional[datetime] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> Storage.Trial:
        """
        Register a new PENDING trial for this experiment.

        Parameters
        ----------
        tunables : TunableGroups
            Tunable values of the trial; mapped to an existing or new config ID.
        ts_start : Optional[datetime]
            Scheduled start time; defaults to now. Normalized via
            ``utcify_timestamp(..., origin="local")``, and microseconds are truncated.
        config : Optional[Dict[str, Any]]
            Framework (non-tunable) parameters to store with the trial.

        Returns
        -------
        Storage.Trial
            The newly created trial. Its ID is the current trial counter,
            which is advanced only after the insert has been committed.
        """
        # MySQL can round microseconds into the future causing scheduler to skip trials.
        # Truncate microseconds to avoid this issue.
        ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local").replace(
            microsecond=0
        )
        _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
        # `Engine.begin()` commits on a clean exit and rolls back if an
        # exception propagates, so no manual rollback is needed here.
        with self._engine.begin() as conn:
            config_id = self._get_config_id(conn, tunables)
            conn.execute(
                self._schema.trial.insert().values(
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                    config_id=config_id,
                    ts_start=ts_start,
                    status="PENDING",
                )
            )

            # Note: config here is the framework config, not the target
            # environment config (i.e., tunables).
            if config is not None:
                self._save_params(
                    conn,
                    self._schema.trial_param,
                    config,
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                )

            trial = Trial(
                engine=self._engine,
                schema=self._schema,
                tunables=tunables,
                experiment_id=self._experiment_id,
                trial_id=self._trial_id,
                config_id=config_id,
                opt_targets=self._opt_targets,
                config=config,
            )
        # The transaction has committed at this point. Advance the counter only
        # now, so a failed commit does not leave it past an ID never stored.
        self._trial_id += 1
        return trial