Coverage for mlos_bench/mlos_bench/storage/sql/trial.py: 98%
55 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5""":py:class:`.Storage.Trial` interface implementation for saving and restoring
6the benchmark trial data using `SQLAlchemy <https://sqlalchemy.org>`_ backend.
7"""
10import logging
11from datetime import datetime
12from typing import Any, Dict, List, Literal, Optional, Tuple
14from sqlalchemy.engine import Connection, Engine
15from sqlalchemy.exc import IntegrityError
17from mlos_bench.environments.status import Status
18from mlos_bench.storage.base_storage import Storage
19from mlos_bench.storage.sql.schema import DbSchema
20from mlos_bench.tunables.tunable_groups import TunableGroups
21from mlos_bench.util import nullable, utcify_timestamp
# Module-level logger, named after this module.
_LOG: logging.Logger = logging.getLogger(name=__name__)
class Trial(Storage.Trial):
    """Store the results of a single run of the experiment in SQL database."""

    # Names (as written to the DB via `Status.name`) of the statuses of trials
    # that have finished; the `trial` row of such a trial is never updated again.
    _FINAL_STATUSES: Tuple[str, ...] = (
        Status.SUCCEEDED.name,
        Status.CANCELED.name,
        Status.FAILED.name,
        Status.TIMED_OUT.name,
    )
    # Names of the statuses of trials that have already started (or finished);
    # such a trial must not get a new `ts_start`.
    _STARTED_STATUSES: Tuple[str, ...] = (Status.RUNNING.name, *_FINAL_STATUSES)

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        opt_targets: Dict[str, Literal["min", "max"]],
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Create a trial handle backed by the SQL storage.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open transactions for every update.
        schema : DbSchema
            Database schema providing the `trial`, `trial_status`,
            `trial_result`, and `trial_telemetry` tables.
        tunables, experiment_id, trial_id, opt_targets, config
            Passed through unchanged to the base class constructor.
        config_id : int
            ID of the tunable configuration; passed to the base class
            as `tunable_config_id`.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            opt_targets=opt_targets,
            config=config,
        )
        self._engine = engine
        self._schema = schema

    def update(
        self,
        status: Status,
        timestamp: datetime,
        metrics: Optional[Dict[str, Any]] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Record a status change of the trial and, on completion, its results.

        A new row is always appended to `trial_status`. If `status` is a
        completed one, the trial's row gets its final status and `ts_end`, and
        `metrics` (if any) are inserted into `trial_result`. Otherwise the
        trial's status and `ts_start` are set, unless the trial has already
        started or finished (that case is only logged). All writes happen in a
        single transaction and are rolled back on any error.

        Parameters
        ----------
        status : Status
            New status of the trial.
        timestamp : datetime
            Time of the status change. It is converted to UTC with
            ``utcify_timestamp(..., origin="local")`` before storing.
        metrics : Optional[Dict[str, Any]]
            Results of the trial. Must be None unless `status` is completed.

        Returns
        -------
        metrics : Optional[Dict[str, Any]]
            The metrics as returned by the base class `update()`.

        Raises
        ------
        RuntimeError
            If a completed status cannot be stored because no matching trial
            row is still open (already ended, already in a final status, or
            missing).
        AssertionError
            If `metrics` are given for a non-completed status.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = super().update(status, timestamp, metrics)
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
            try:
                if status.is_completed():
                    # Final update of the status and ts_end:
                    rowcount = self._update_trial_row(
                        conn,
                        self._FINAL_STATUSES,
                        status=status.name,
                        ts_end=timestamp,
                    )
                    # Per the DB-API (PEP 249), rowcount is -1 when the driver
                    # cannot determine it; accept that as "unknown, assume ok".
                    if rowcount not in {1, -1}:
                        _LOG.warning("Trial %s :: update failed: %s", self, status)
                        raise RuntimeError(
                            f"Failed to update the status of the trial {self} to {status}. "
                            f"({rowcount} rows)"
                        )
                    if metrics:
                        conn.execute(
                            self._schema.trial_result.insert().values(
                                [
                                    {
                                        "exp_id": self._experiment_id,
                                        "trial_id": self._trial_id,
                                        "metric_id": key,
                                        "metric_value": nullable(str, val),
                                    }
                                    for (key, val) in metrics.items()
                                ]
                            )
                        )
                else:
                    # Update of the status and ts_start when starting the trial:
                    assert metrics is None, f"Unexpected metrics for status: {status}"
                    rowcount = self._update_trial_row(
                        conn,
                        self._STARTED_STATUSES,
                        status=status.name,
                        ts_start=timestamp,
                    )
                    if rowcount not in {1, -1}:
                        # Keep the old status and timestamp if already running, but log it.
                        _LOG.warning("Trial %s :: cannot be updated to: %s", self, status)
            except Exception:
                conn.rollback()
                raise
        return metrics

    def _update_trial_row(
        self,
        conn: Connection,
        exclude_statuses: Tuple[str, ...],
        **values: Any,
    ) -> int:
        """
        Update this trial's row in the `trial` table within `conn`.

        The row is only touched if it has not ended yet (`ts_end IS NULL`) and
        its current status is not one of `exclude_statuses`.

        Parameters
        ----------
        conn : Connection
            Connection of the caller's transaction.
        exclude_statuses : Tuple[str, ...]
            Status names for which the row must be left unchanged.
        **values : Any
            Column values to set (e.g. `status`, `ts_start`, `ts_end`).

        Returns
        -------
        rowcount : int
            Number of rows reported as updated by the driver.
        """
        table = self._schema.trial
        result = conn.execute(
            table.update()
            .where(
                table.c.exp_id == self._experiment_id,
                table.c.trial_id == self._trial_id,
                table.c.ts_end.is_(None),
                table.c.status.notin_(exclude_statuses),
            )
            .values(**values)
        )
        return result.rowcount

    def update_telemetry(
        self,
        status: Status,
        timestamp: datetime,
        metrics: List[Tuple[datetime, str, Any]],
    ) -> None:
        """
        Store intermediate telemetry of the trial along with a status record.

        Parameters
        ----------
        status : Status
            Current status of the trial; appended to `trial_status`.
        timestamp : datetime
            Time of the status record; converted to UTC before storing.
        metrics : List[Tuple[datetime, str, Any]]
            ``(timestamp, metric_id, value)`` tuples to insert into
            `trial_telemetry`. Each timestamp is converted to UTC.
        """
        super().update_telemetry(status, timestamp, metrics)
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = [(utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics]
        # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`
        # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of
        # a bulk upsert.
        # See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
        for metric_ts, key, val in metrics:
            # One transaction per row, so a duplicate only skips that row
            # instead of failing (or aborting) the whole batch.
            with self._engine.begin() as conn:
                try:
                    conn.execute(
                        self._schema.trial_telemetry.insert().values(
                            exp_id=self._experiment_id,
                            trial_id=self._trial_id,
                            ts=metric_ts,
                            metric_id=key,
                            metric_value=nullable(str, val),
                        )
                    )
                except IntegrityError as ex:
                    _LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex)

    def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """
        Insert a new status record into the `trial_status` table.

        The insert runs on `conn`, i.e. inside the caller's transaction.

        This call is idempotent: if the record already exists, the resulting
        IntegrityError is logged and ignored.

        NOTE(review): on backends that abort the whole transaction after any
        error (e.g. PostgreSQL), swallowing the IntegrityError here may make
        later statements in the caller's transaction fail -- confirm.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        try:
            conn.execute(
                self._schema.trial_status.insert().values(
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                    ts=timestamp,
                    status=status.name,
                )
            )
        except IntegrityError as ex:
            _LOG.warning(
                "Status with that timestamp already exists: %s %s :: %s",
                self,
                timestamp,
                ex,
            )