Coverage for mlos_bench/mlos_bench/storage/sql/trial.py: 99%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""":py:class:`.Storage.Trial` interface implementation for saving and restoring 

6the benchmark trial data using `SQLAlchemy <https://sqlalchemy.org>`_ backend. 

7""" 

8 

9 

10import logging 

11from collections.abc import Mapping 

12from datetime import datetime 

13from typing import Any, Literal 

14 

15from sqlalchemy import or_ 

16from sqlalchemy.engine import Connection, Engine 

17from sqlalchemy.exc import IntegrityError 

18 

19from mlos_bench.environments.status import Status 

20from mlos_bench.storage.base_storage import Storage 

21from mlos_bench.storage.sql.common import save_params 

22from mlos_bench.storage.sql.schema import DbSchema 

23from mlos_bench.tunables.tunable_groups import TunableGroups 

24from mlos_bench.util import nullable, utcify_timestamp 

25 

# Module-level logger, named after this module so log records can be filtered per module.
_LOG = logging.getLogger(name=__name__)

27 

28 

class Trial(Storage.Trial):
    """
    Store the results of a single run of the experiment in SQL database.

    Every read and write goes through the SQLAlchemy ``Engine`` given to the
    constructor, using the table objects exposed by the ``DbSchema`` instance
    (``trial``, ``trial_param``, ``trial_result``, ``trial_status`` and
    ``trial_telemetry``).
    """

    # Status names for which `update()` refuses to modify the trial row again
    # (both of its `notin_` filters exclude them). Defined once so the "finish"
    # and "start" updates always agree on what a final state is.
    _FINAL_STATUS_NAMES: tuple[str, ...] = (
        Status.SUCCEEDED.name,
        Status.CANCELED.name,
        Status.FAILED.name,
        Status.TIMED_OUT.name,
    )

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        trial_runner_id: int | None = None,
        opt_targets: dict[str, Literal["min", "max"]],
        config: dict[str, Any] | None = None,
        status: Status = Status.UNKNOWN,
    ):
        """
        Create a SQL-backed handle for a single trial.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine; each storage operation opens its own transaction on it.
        schema : DbSchema
            Table definitions used to build the SQL statements.
        tunables, experiment_id, trial_id, trial_runner_id, opt_targets, config, status
            Forwarded unchanged to the base ``Storage.Trial`` constructor.
        config_id : int
            Forwarded to the base constructor as ``tunable_config_id``.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            trial_runner_id=trial_runner_id,
            opt_targets=opt_targets,
            config=config,
            status=status,
        )
        self._engine = engine
        self._schema = schema

    def set_trial_runner(self, trial_runner_id: int) -> int:
        """
        Assign this trial to a trial runner and persist the assignment.

        The trial row is only updated when it has no runner yet or is still
        ``PENDING``. The stored value is then re-read in a separate transaction,
        so the returned id (and ``self._trial_runner_id``) is whatever the
        database actually holds. If the conditional update did not apply, that is
        the previously stored runner id.

        Parameters
        ----------
        trial_runner_id : int
            Requested runner id (first passed through the base class).

        Returns
        -------
        int
            The trial runner id currently stored for this trial.
        """
        trial_runner_id = super().set_trial_runner(trial_runner_id)
        trial_table = self._schema.trial
        with self._engine.begin() as conn:
            conn.execute(
                trial_table.update()
                .where(
                    trial_table.c.exp_id == self._experiment_id,
                    trial_table.c.trial_id == self._trial_id,
                    or_(
                        trial_table.c.trial_runner_id.is_(None),
                        trial_table.c.status == Status.PENDING.name,
                    ),
                )
                .values(trial_runner_id=trial_runner_id)
            )
        # Guard against concurrent updates: trust the stored value, not our request.
        with self._engine.begin() as conn:
            trial_runner_rs = conn.execute(
                trial_table.select()
                .with_only_columns(trial_table.c.trial_runner_id)
                .where(
                    trial_table.c.exp_id == self._experiment_id,
                    trial_table.c.trial_id == self._trial_id,
                )
            )
            trial_runner_row = trial_runner_rs.fetchone()
            assert trial_runner_row
            self._trial_runner_id = trial_runner_row.trial_runner_id
            assert isinstance(self._trial_runner_id, int)
            return self._trial_runner_id

    def _save_new_config_data(self, new_config_data: Mapping[str, int | float | str]) -> None:
        """
        Persist additional per-trial parameters into the ``trial_param`` table.

        Parameters
        ----------
        new_config_data : Mapping[str, int | float | str]
            Parameter name/value pairs, written via ``save_params`` for this
            experiment and trial in a single transaction.
        """
        with self._engine.begin() as conn:
            save_params(
                conn,
                self._schema.trial_param,
                new_config_data,
                exp_id=self._experiment_id,
                trial_id=self._trial_id,
            )

    def update(
        self,
        status: Status,
        timestamp: datetime,
        metrics: dict[str, Any] | None = None,
    ) -> dict[str, Any] | None:
        """
        Record a status change of the trial and, on completion, its final metrics.

        A status-history record is always inserted. For a completed status, the
        trial row gets its final ``status``/``ts_end`` and the metrics are
        inserted into ``trial_result``. Otherwise the row gets ``status``/``ts_start``,
        unless the trial is already running or finished, in which case a warning
        is logged and the row is left as is.

        Parameters
        ----------
        status : Status
            New status of the trial.
        timestamp : datetime
            Time of the status change, converted to UTC with
            ``utcify_timestamp(..., origin="local")`` before storing.
        metrics : dict[str, Any] | None
            Final metrics. Only allowed together with a completed status.

        Returns
        -------
        dict[str, Any] | None
            The metrics as returned by the base class ``update()``.

        Raises
        ------
        RuntimeError
            If the final status update does not match exactly one row, e.g.
            because the trial already has ``ts_end`` set or a final status.
        AssertionError
            If metrics are passed with a non-completed status.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = super().update(status, timestamp, metrics)
        trial_table = self._schema.trial
        # NOTE: `Engine.begin()` commits when the block completes and rolls back
        # if any exception escapes it, so no manual rollback is needed here.
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
            if status.is_completed():
                # Final update of the status and ts_end:
                cur_status = conn.execute(
                    trial_table.update()
                    .where(
                        trial_table.c.exp_id == self._experiment_id,
                        trial_table.c.trial_id == self._trial_id,
                        trial_table.c.ts_end.is_(None),
                        trial_table.c.status.notin_(list(self._FINAL_STATUS_NAMES)),
                    )
                    .values(
                        status=status.name,
                        ts_end=timestamp,
                    )
                )
                # A rowcount of -1 means the DBAPI driver could not determine the
                # count; that is treated as success.
                if cur_status.rowcount not in {1, -1}:
                    _LOG.warning("Trial %s :: update failed: %s", self, status)
                    raise RuntimeError(
                        f"Failed to update the status of the trial {self} to {status}. "
                        f"({cur_status.rowcount} rows)"
                    )
                # Nothing to insert when there are no metrics.
                if metrics:
                    conn.execute(
                        self._schema.trial_result.insert().values(
                            [
                                {
                                    "exp_id": self._experiment_id,
                                    "trial_id": self._trial_id,
                                    "metric_id": key,
                                    "metric_value": nullable(str, val),
                                }
                                for (key, val) in metrics.items()
                            ]
                        )
                    )
            else:
                # Update of the status and ts_start when starting the trial:
                assert metrics is None, f"Unexpected metrics for status: {status}"
                cur_status = conn.execute(
                    trial_table.update()
                    .where(
                        trial_table.c.exp_id == self._experiment_id,
                        trial_table.c.trial_id == self._trial_id,
                        trial_table.c.ts_end.is_(None),
                        trial_table.c.status.notin_(
                            [Status.RUNNING.name, *self._FINAL_STATUS_NAMES]
                        ),
                    )
                    .values(
                        status=status.name,
                        ts_start=timestamp,
                    )
                )
                if cur_status.rowcount not in {1, -1}:
                    # Keep the old status and timestamp if already running, but log it.
                    _LOG.warning("Trial %s :: cannot be updated to: %s", self, status)
        return metrics

    def update_telemetry(
        self,
        status: Status,
        timestamp: datetime,
        metrics: list[tuple[datetime, str, Any]],
    ) -> None:
        """
        Record a status change plus a batch of telemetry samples for the trial.

        Duplicate samples (rejected with an ``IntegrityError``) are logged and
        skipped, so calling this again with the same data is harmless.

        Parameters
        ----------
        status : Status
            Current status of the trial.
        timestamp : datetime
            Time of the status record, converted to UTC with
            ``utcify_timestamp(..., origin="local")``.
        metrics : list[tuple[datetime, str, Any]]
            ``(sample timestamp, metric id, value)`` triples. Timestamps are
            converted to UTC the same way. Values are stored via ``nullable(str, ...)``.
        """
        # Make sure to convert the timestamps to UTC before storing them in the
        # database, and before handing them to the base class, as `update()` does.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = [(utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics]
        super().update_telemetry(status, timestamp, metrics)
        # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`
        # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of
        # a bulk upsert.
        # See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
        telemetry_table = self._schema.trial_telemetry
        for metric_ts, key, val in metrics:
            # One transaction per sample, so a duplicate only skips that one sample.
            with self._engine.begin() as conn:
                try:
                    conn.execute(
                        telemetry_table.insert().values(
                            exp_id=self._experiment_id,
                            trial_id=self._trial_id,
                            ts=metric_ts,
                            metric_id=key,
                            metric_value=nullable(str, val),
                        )
                    )
                except IntegrityError as ex:
                    _LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex)

    def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """
        Insert a new status record into the database.

        This call is idempotent: if a record with the same key already exists,
        the ``IntegrityError`` is logged as a warning and ignored.

        Parameters
        ----------
        conn : Connection
            Connection (with an open transaction) supplied by the caller.
        status : Status
            Status to record.
        timestamp : datetime
            Time of the status change, converted to UTC before storing.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        # NOTE(review): on backends where a failed statement aborts the whole
        # transaction (e.g. PostgreSQL), swallowing the IntegrityError here may make
        # the caller's later statements on `conn` fail; a SAVEPOINT
        # (`conn.begin_nested()`) would isolate it. TODO confirm supported backends.
        try:
            conn.execute(
                self._schema.trial_status.insert().values(
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                    ts=timestamp,
                    status=status.name,
                )
            )
        except IntegrityError as ex:
            _LOG.warning(
                "Status with that timestamp already exists: %s %s :: %s",
                self,
                timestamp,
                ex,
            )