Coverage for mlos_bench/mlos_bench/storage/sql/trial.py: 98%
55 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Saving and updating benchmark data using SQLAlchemy backend."""
7import logging
8from datetime import datetime
9from typing import Any, Dict, List, Literal, Optional, Tuple
11from sqlalchemy import Connection, Engine
12from sqlalchemy.exc import IntegrityError
14from mlos_bench.environments.status import Status
15from mlos_bench.storage.base_storage import Storage
16from mlos_bench.storage.sql.schema import DbSchema
17from mlos_bench.tunables.tunable_groups import TunableGroups
18from mlos_bench.util import nullable, utcify_timestamp
# Logger for this module; messages are tagged with the module's dotted name.
_LOG: logging.Logger = logging.getLogger(name=__name__)
class Trial(Storage.Trial):
    """Store the results of a single run of the experiment in SQL database."""

    # Names of the terminal (completed) statuses as stored in the `trial.status`
    # column. Derived from the `Status` enum so the SQL filters below cannot drift
    # from it through a typo; same values and order as the former string literals.
    _COMPLETED_STATUS_NAMES: Tuple[str, ...] = (
        Status.SUCCEEDED.name,
        Status.CANCELED.name,
        Status.FAILED.name,
        Status.TIMED_OUT.name,
    )

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        opt_targets: Dict[str, Literal["min", "max"]],
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Create a handle for a single trial stored in a SQL database.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open a transaction for every update.
        schema : DbSchema
            Schema object providing the `trial`, `trial_status`, `trial_result`,
            and `trial_telemetry` tables used by this class.
        tunables : TunableGroups
            Tunables of the experiment (passed through to the base class).
        experiment_id : str
            ID of the experiment this trial belongs to.
        trial_id : int
            ID of the trial within the experiment.
        config_id : int
            ID of the tunable configuration (passed to the base class as
            `tunable_config_id`).
        opt_targets : Dict[str, Literal["min", "max"]]
            Optimization targets and their directions.
        config : Optional[Dict[str, Any]]
            Optional extra trial configuration (passed through to the base class).
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            opt_targets=opt_targets,
            config=config,
        )
        self._engine = engine
        self._schema = schema

    def update(
        self,
        status: Status,
        timestamp: datetime,
        metrics: Optional[Dict[str, Any]] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Record a status change of the trial, and its final metrics if completed.

        A new `trial_status` record is always inserted. For a completed status the
        `trial` row gets its final status and `ts_end` and the metrics (if any) are
        stored in `trial_result`; otherwise the `trial` row gets the new status and
        `ts_start`, unless it has already ended or is running/completed.
        All writes happen in a single transaction.

        Parameters
        ----------
        status : Status
            New status of the trial.
        timestamp : datetime
            Time of the status change; converted with
            ``utcify_timestamp(..., origin="local")`` before it is stored.
        metrics : Optional[Dict[str, Any]]
            Final metrics; only allowed together with a completed status.

        Returns
        -------
        metrics : Optional[Dict[str, Any]]
            The metrics as returned by the base class ``update()``.

        Raises
        ------
        RuntimeError
            If the final update of a completed trial does not match exactly one row.
        AssertionError
            If metrics are given together with a non-completed status.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = super().update(status, timestamp, metrics)
        # `Engine.begin()` commits on success and rolls back automatically if any
        # statement below raises, so no explicit rollback is needed.
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
            if status.is_completed():
                self._update_trial_end(conn, status, timestamp, metrics)
            else:
                assert metrics is None, f"Unexpected metrics for status: {status}"
                self._update_trial_start(conn, status, timestamp)
        return metrics

    def _update_trial_end(
        self,
        conn: Connection,
        status: Status,
        timestamp: datetime,
        metrics: Optional[Dict[str, Any]],
    ) -> None:
        """Set the final status and `ts_end` of the trial and store its metrics."""
        # Final update of the status and ts_end; only applies to a trial row that
        # has not ended yet and is not already in a completed status.
        cur_status = conn.execute(
            self._schema.trial.update()
            .where(
                self._schema.trial.c.exp_id == self._experiment_id,
                self._schema.trial.c.trial_id == self._trial_id,
                self._schema.trial.c.ts_end.is_(None),
                self._schema.trial.c.status.notin_(list(self._COMPLETED_STATUS_NAMES)),
            )
            .values(
                status=status.name,
                ts_end=timestamp,
            )
        )
        # -1 is also accepted, presumably for drivers that cannot report a row count.
        if cur_status.rowcount not in {1, -1}:
            _LOG.warning("Trial %s :: update failed: %s", self, status)
            raise RuntimeError(
                f"Failed to update the status of the trial {self} to {status}. "
                f"({cur_status.rowcount} rows)"
            )
        if metrics:
            # Metric values are stored as strings (None stays NULL).
            conn.execute(
                self._schema.trial_result.insert().values(
                    [
                        {
                            "exp_id": self._experiment_id,
                            "trial_id": self._trial_id,
                            "metric_id": key,
                            "metric_value": nullable(str, val),
                        }
                        for (key, val) in metrics.items()
                    ]
                )
            )

    def _update_trial_start(
        self,
        conn: Connection,
        status: Status,
        timestamp: datetime,
    ) -> None:
        """Move a not-yet-started trial to `status` and set its `ts_start`."""
        # Update of the status and ts_start when starting the trial; skipped when
        # the trial has ended or is already running/completed.
        cur_status = conn.execute(
            self._schema.trial.update()
            .where(
                self._schema.trial.c.exp_id == self._experiment_id,
                self._schema.trial.c.trial_id == self._trial_id,
                self._schema.trial.c.ts_end.is_(None),
                self._schema.trial.c.status.notin_(
                    [Status.RUNNING.name, *self._COMPLETED_STATUS_NAMES]
                ),
            )
            .values(
                status=status.name,
                ts_start=timestamp,
            )
        )
        if cur_status.rowcount not in {1, -1}:
            # Keep the old status and timestamp if already running, but log it.
            _LOG.warning("Trial %s :: cannot be updated to: %s", self, status)

    def update_telemetry(
        self,
        status: Status,
        timestamp: datetime,
        metrics: List[Tuple[datetime, str, Any]],
    ) -> None:
        """
        Record a status record and intermediate telemetry of a running trial.

        Parameters
        ----------
        status : Status
            Current status of the trial.
        timestamp : datetime
            Time of the status record; converted with
            ``utcify_timestamp(..., origin="local")`` before it is stored.
        metrics : List[Tuple[datetime, str, Any]]
            Telemetry as ``(timestamp, metric_id, value)`` triples; each timestamp is
            converted the same way. Records that fail to insert with an
            IntegrityError are logged as warnings and skipped.
        """
        super().update_telemetry(status, timestamp, metrics)
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = [(utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics]
        # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`
        # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of
        # a bulk upsert.
        # See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
        # One transaction per record, so a duplicate only affects its own insert.
        for metric_ts, key, val in metrics:
            with self._engine.begin() as conn:
                try:
                    conn.execute(
                        self._schema.trial_telemetry.insert().values(
                            exp_id=self._experiment_id,
                            trial_id=self._trial_id,
                            ts=metric_ts,
                            metric_id=key,
                            metric_value=nullable(str, val),
                        )
                    )
                except IntegrityError as ex:
                    _LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex)

    def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """
        Insert a new status record into the database.

        Repeated calls with the same timestamp are tolerated: an IntegrityError from
        the insert is logged as a warning and swallowed.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        # NOTE(review): on backends such as PostgreSQL a failed statement aborts the
        # whole enclosing transaction, so later statements on `conn` may fail after
        # this IntegrityError is swallowed. Consider wrapping the insert in a
        # SAVEPOINT (`conn.begin_nested()`); verify pysqlite SAVEPOINT behavior first.
        try:
            conn.execute(
                self._schema.trial_status.insert().values(
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                    ts=timestamp,
                    status=status.name,
                )
            )
        except IntegrityError as ex:
            _LOG.warning(
                "Status with that timestamp already exists: %s %s :: %s",
                self,
                timestamp,
                ex,
            )