Coverage for mlos_bench/mlos_bench/storage/sql/trial.py: 99%
72 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5""":py:class:`.Storage.Trial` interface implementation for saving and restoring
6the benchmark trial data using `SQLAlchemy <https://sqlalchemy.org>`_ backend.
7"""
10import logging
11from collections.abc import Mapping
12from datetime import datetime
13from typing import Any, Literal
15from sqlalchemy import or_
16from sqlalchemy.engine import Connection, Engine
17from sqlalchemy.exc import IntegrityError
19from mlos_bench.environments.status import Status
20from mlos_bench.storage.base_storage import Storage
21from mlos_bench.storage.sql.common import save_params
22from mlos_bench.storage.sql.schema import DbSchema
23from mlos_bench.tunables.tunable_groups import TunableGroups
24from mlos_bench.util import nullable, utcify_timestamp
26_LOG = logging.getLogger(__name__)
class Trial(Storage.Trial):
    """
    Store the results of a single run of the experiment in SQL database.

    Every database operation opens its own transaction on the SQLAlchemy engine
    given to the constructor and builds its statements from the given schema.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        trial_runner_id: int | None = None,
        opt_targets: dict[str, Literal["min", "max"]],
        config: dict[str, Any] | None = None,
        status: Status = Status.UNKNOWN,
    ):
        """
        Create a trial bound to a SQL storage backend.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open the transactions.
        schema : DbSchema
            Table definitions the SQL statements are built from.

        All other arguments are forwarded unchanged to the base class
        constructor; ``config_id`` is passed on as ``tunable_config_id``.
        """
        # Initialize the base class first, then attach the SQL-specific handles.
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            trial_runner_id=trial_runner_id,
            opt_targets=opt_targets,
            config=config,
            status=status,
        )
        self._engine = engine
        self._schema = schema

    def set_trial_runner(self, trial_runner_id: int) -> int:
        """
        Assign the trial to a trial runner and persist that assignment.

        The stored value is overwritten only if the trial has no runner yet or is
        still in the PENDING status. The value that actually ends up in the
        database is then read back, cached on the object, and returned.

        Parameters
        ----------
        trial_runner_id : int
            ID of the trial runner to assign.

        Returns
        -------
        int
            The trial runner ID stored in the database after the update attempt.
        """
        runner_id = super().set_trial_runner(trial_runner_id)

        trial_tbl = self._schema.trial
        this_trial = (
            trial_tbl.c.exp_id == self._experiment_id,
            trial_tbl.c.trial_id == self._trial_id,
        )
        can_assign = or_(
            trial_tbl.c.trial_runner_id.is_(None),
            trial_tbl.c.status == Status.PENDING.name,
        )

        assign_stmt = (
            trial_tbl.update().where(*this_trial, can_assign).values(trial_runner_id=runner_id)
        )
        with self._engine.begin() as conn:
            conn.execute(assign_stmt)

        # Guard against concurrent updates: trust what is stored in the database
        # rather than the value we have just tried to write.
        readback_stmt = (
            trial_tbl.select().with_only_columns(trial_tbl.c.trial_runner_id).where(*this_trial)
        )
        with self._engine.begin() as conn:
            stored_row = conn.execute(readback_stmt).fetchone()
            assert stored_row
            self._trial_runner_id = stored_row.trial_runner_id
        assert isinstance(self._trial_runner_id, int)
        return self._trial_runner_id

    def _save_new_config_data(self, new_config_data: Mapping[str, int | float | str]) -> None:
        """
        Save the new configuration values as parameters of this trial.

        Parameters
        ----------
        new_config_data : Mapping[str, int | float | str]
            Key/value pairs handed to `save_params` for the ``trial_param`` table.
        """
        trial_key = {"exp_id": self._experiment_id, "trial_id": self._trial_id}
        with self._engine.begin() as conn:
            save_params(conn, self._schema.trial_param, new_config_data, **trial_key)

    def update(
        self,
        status: Status,
        timestamp: datetime,
        metrics: dict[str, Any] | None = None,
    ) -> dict[str, Any] | None:
        """
        Record a new status of the trial and, on completion, its results.

        A status record is inserted first. Then the trial row is updated in the
        same transaction, but only while its ``ts_end`` is still unset:

        * Completed status: set ``status`` and ``ts_end`` unless the stored status
          is already one of SUCCEEDED, CANCELED, FAILED, or TIMED_OUT, and insert
          the metrics (if any) into the ``trial_result`` table.
        * Any other status: set ``status`` and ``ts_start`` unless the stored
          status is RUNNING or one of the four statuses above. If nothing can be
          updated, the old values are kept and a warning is logged.

        Parameters
        ----------
        status : Status
            New status of the trial.
        timestamp : datetime
            Time of the status change; converted to UTC before being stored.
        metrics : dict[str, Any] | None
            Results of the trial; must be None unless the status is completed.

        Returns
        -------
        dict[str, Any] | None
            The metrics as returned by the base class `update` method.

        Raises
        ------
        RuntimeError
            If the final update of a completed trial reports a row count other
            than 1 or -1.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = super().update(status, timestamp, metrics)

        trial_tbl = self._schema.trial
        # Criteria shared by both kinds of update: this trial, not finished yet.
        unfinished_trial = (
            trial_tbl.c.exp_id == self._experiment_id,
            trial_tbl.c.trial_id == self._trial_id,
            trial_tbl.c.ts_end.is_(None),
        )
        final_statuses = [
            Status.SUCCEEDED.name,
            Status.CANCELED.name,
            Status.FAILED.name,
            Status.TIMED_OUT.name,
        ]

        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
            try:
                if not status.is_completed():
                    # Update of the status and ts_start when starting the trial:
                    assert metrics is None, f"Unexpected metrics for status: {status}"
                    start_stmt = (
                        trial_tbl.update()
                        .where(
                            *unfinished_trial,
                            trial_tbl.c.status.notin_([Status.RUNNING.name, *final_statuses]),
                        )
                        .values(status=status.name, ts_start=timestamp)
                    )
                    started = conn.execute(start_stmt)
                    if started.rowcount not in (1, -1):
                        # Keep the old status and timestamp if already running, but log it.
                        _LOG.warning("Trial %s :: cannot be updated to: %s", self, status)
                else:
                    # Final update of the status and ts_end:
                    finish_stmt = (
                        trial_tbl.update()
                        .where(*unfinished_trial, trial_tbl.c.status.notin_(final_statuses))
                        .values(status=status.name, ts_end=timestamp)
                    )
                    finished = conn.execute(finish_stmt)
                    num_rows = finished.rowcount
                    if num_rows not in (1, -1):
                        _LOG.warning("Trial %s :: update failed: %s", self, status)
                        raise RuntimeError(
                            f"Failed to update the status of the trial {self} to {status}. "
                            f"({num_rows} rows)"
                        )
                    if metrics:
                        result_rows = [
                            {
                                "exp_id": self._experiment_id,
                                "trial_id": self._trial_id,
                                "metric_id": metric_id,
                                "metric_value": nullable(str, metric_val),
                            }
                            for (metric_id, metric_val) in metrics.items()
                        ]
                        conn.execute(self._schema.trial_result.insert().values(result_rows))
            except Exception:
                conn.rollback()
                raise
        return metrics

    def update_telemetry(
        self,
        status: Status,
        timestamp: datetime,
        metrics: list[tuple[datetime, str, Any]],
    ) -> None:
        """
        Record the current status of the trial and save its telemetry.

        The status record is inserted in one transaction; after that, every
        telemetry record is inserted in a transaction of its own. A record that
        violates an integrity constraint is logged and skipped.

        Parameters
        ----------
        status : Status
            Current status of the trial.
        timestamp : datetime
            Time of the status report; converted to UTC before being stored.
        metrics : list[tuple[datetime, str, Any]]
            Telemetry as (timestamp, metric ID, value) triplets; each timestamp
            is converted to UTC before being stored.
        """
        super().update_telemetry(status, timestamp, metrics)
        # Make sure to convert the timestamps to UTC before storing them in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        utc_records = [
            (utcify_timestamp(metric_ts, origin="local"), metric_id, metric_val)
            for (metric_ts, metric_id, metric_val) in metrics
        ]

        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)

        # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`
        # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of
        # a bulk upsert.
        # See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
        telemetry_tbl = self._schema.trial_telemetry
        for record in utc_records:
            (metric_ts, metric_id, metric_val) = record
            with self._engine.begin() as conn:
                try:
                    conn.execute(
                        telemetry_tbl.insert().values(
                            exp_id=self._experiment_id,
                            trial_id=self._trial_id,
                            ts=metric_ts,
                            metric_id=metric_id,
                            metric_value=nullable(str, metric_val),
                        )
                    )
                except IntegrityError as ex:
                    _LOG.warning("Record already exists: %s :: %s", record, ex)

    def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """
        Insert a new status record into the database.

        This call is idempotent: an insert rejected with an `IntegrityError` is
        logged as a warning and otherwise ignored.

        Parameters
        ----------
        conn : Connection
            Open connection to run the insert on; the caller owns the transaction.
        status : Status
            Status of the trial to record.
        timestamp : datetime
            Time of the status change; converted to UTC before being stored.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        insert_status = self._schema.trial_status.insert().values(
            exp_id=self._experiment_id,
            trial_id=self._trial_id,
            ts=timestamp,
            status=status.name,
        )
        try:
            conn.execute(insert_status)
        except IntegrityError as ex:
            _LOG.warning(
                "Status with that timestamp already exists: %s %s :: %s",
                self,
                timestamp,
                ex,
            )