Coverage for mlos_bench/mlos_bench/storage/sql/trial.py: 98%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""":py:class:`.Storage.Trial` interface implementation for saving and restoring 

6the benchmark trial data using `SQLAlchemy <https://sqlalchemy.org>`_ backend. 

7""" 

8 

9 

10import logging 

11from datetime import datetime 

12from typing import Any, Dict, List, Literal, Optional, Tuple 

13 

14from sqlalchemy.engine import Connection, Engine 

15from sqlalchemy.exc import IntegrityError 

16 

17from mlos_bench.environments.status import Status 

18from mlos_bench.storage.base_storage import Storage 

19from mlos_bench.storage.sql.schema import DbSchema 

20from mlos_bench.tunables.tunable_groups import TunableGroups 

21from mlos_bench.util import nullable, utcify_timestamp 

22 

# Module-level logger, named after this module.
_LOG = logging.getLogger(__name__)

24 

25 

class Trial(Storage.Trial):
    """
    Store the results of a single run of the experiment in SQL database.

    Every public update method opens its own transaction(s) via
    ``engine.begin()``, so each call is committed on success and rolled back
    if an exception escapes the transaction block.
    """

    # Status names of a finished trial. The UPDATE statements in `update()`
    # never modify a `trial` row that already carries one of these statuses.
    _FINAL_STATUS_NAMES = ("SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT")

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        opt_targets: Dict[str, Literal["min", "max"]],
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Create a new SQL-backed trial object.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open a transaction for each update.
        schema : DbSchema
            Table definitions this object reads from and writes to
            (``trial``, ``trial_status``, ``trial_result``, ``trial_telemetry``).
        tunables : TunableGroups
            Tunables of the trial (passed to the base class).
        experiment_id : str
            ID of the experiment this trial belongs to.
        trial_id : int
            ID of the trial within the experiment.
        config_id : int
            ID of the tunable configuration; passed to the base class as
            ``tunable_config_id``.
        opt_targets : Dict[str, Literal["min", "max"]]
            Optimization targets and their directions (passed to the base class).
        config : Optional[Dict[str, Any]]
            Optional extra trial configuration (passed to the base class).
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            opt_targets=opt_targets,
            config=config,
        )
        self._engine = engine
        self._schema = schema

    def update(
        self,
        status: Status,
        timestamp: datetime,
        metrics: Optional[Dict[str, Any]] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Record a status change of the trial and, on completion, its results.

        A row is always added to the ``trial_status`` table. Then:

        * if ``status.is_completed()``, the ``trial`` row gets the new
          ``status`` and ``ts_end``, and any non-empty ``metrics`` are
          inserted into ``trial_result``;
        * otherwise ``metrics`` must be ``None``, and the ``trial`` row gets
          the new ``status`` and ``ts_start``, unless it is already running
          or finished. In that case only a warning is logged.

        All writes happen in one transaction.

        Parameters
        ----------
        status : Status
            New status of the trial.
        timestamp : datetime
            Time of the status change. It is converted to UTC with
            ``utcify_timestamp(..., origin="local")`` before storing.
        metrics : Optional[Dict[str, Any]]
            Final results of the trial, for completed statuses only.
            Values are stored via ``nullable(str, value)``.

        Returns
        -------
        metrics : Optional[Dict[str, Any]]
            The metrics as returned by the base class ``update()``.

        Raises
        ------
        RuntimeError
            If a completed status could not be written to exactly one
            unfinished ``trial`` row. A rowcount of -1 (unknown) is accepted.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = super().update(status, timestamp, metrics)
        trial_table = self._schema.trial
        # Match only this trial's row, and only while it has not ended yet.
        unfinished_trial = (
            trial_table.c.exp_id == self._experiment_id,
            trial_table.c.trial_id == self._trial_id,
            trial_table.c.ts_end.is_(None),
        )
        # NOTE: `engine.begin()` commits on success and rolls the transaction
        # back if anything below raises, so no explicit rollback is needed.
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
            if status.is_completed():
                # Final update of the status and ts_end:
                cur_status = conn.execute(
                    trial_table.update()
                    .where(
                        *unfinished_trial,
                        trial_table.c.status.notin_(list(self._FINAL_STATUS_NAMES)),
                    )
                    .values(
                        status=status.name,
                        ts_end=timestamp,
                    )
                )
                # Some drivers report -1 when the number of matched rows is unknown.
                if cur_status.rowcount not in {1, -1}:
                    _LOG.warning("Trial %s :: update failed: %s", self, status)
                    raise RuntimeError(
                        f"Failed to update the status of the trial {self} to {status}. "
                        f"({cur_status.rowcount} rows)"
                    )
                if metrics:
                    conn.execute(
                        self._schema.trial_result.insert().values(
                            [
                                {
                                    "exp_id": self._experiment_id,
                                    "trial_id": self._trial_id,
                                    "metric_id": key,
                                    "metric_value": nullable(str, val),
                                }
                                for (key, val) in metrics.items()
                            ]
                        )
                    )
            else:
                # Update of the status and ts_start when starting the trial:
                assert metrics is None, f"Unexpected metrics for status: {status}"
                cur_status = conn.execute(
                    trial_table.update()
                    .where(
                        *unfinished_trial,
                        trial_table.c.status.notin_(["RUNNING", *self._FINAL_STATUS_NAMES]),
                    )
                    .values(
                        status=status.name,
                        ts_start=timestamp,
                    )
                )
                if cur_status.rowcount not in {1, -1}:
                    # Keep the old status and timestamp if already running, but log it.
                    _LOG.warning("Trial %s :: cannot be updated to: %s", self, status)
        return metrics

    def update_telemetry(
        self,
        status: Status,
        timestamp: datetime,
        metrics: List[Tuple[datetime, str, Any]],
    ) -> None:
        """
        Record the current trial status and a batch of telemetry records.

        Parameters
        ----------
        status : Status
            Current status of the trial. A ``trial_status`` row is added for it.
        timestamp : datetime
            Time of the status record. It is converted to UTC with
            ``utcify_timestamp(..., origin="local")`` before storing.
        metrics : List[Tuple[datetime, str, Any]]
            ``(timestamp, metric_id, value)`` records to insert into
            ``trial_telemetry``. Each timestamp is converted to UTC as above,
            and each value is stored via ``nullable(str, value)``.
            Records that already exist are skipped with a warning.
        """
        super().update_telemetry(status, timestamp, metrics)
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        metrics = [(utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics]
        with self._engine.begin() as conn:
            self._update_status(conn, status, timestamp)
        # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`
        # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of
        # a bulk upsert.
        # See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
        for metric_ts, key, val in metrics:
            # One transaction per record: a duplicate (IntegrityError) only
            # skips that single record instead of the whole batch.
            with self._engine.begin() as conn:
                try:
                    conn.execute(
                        self._schema.trial_telemetry.insert().values(
                            exp_id=self._experiment_id,
                            trial_id=self._trial_id,
                            ts=metric_ts,
                            metric_id=key,
                            metric_value=nullable(str, val),
                        )
                    )
                except IntegrityError as ex:
                    _LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex)

    def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """
        Insert a new status record into the database.

        This call is idempotent. If the insert raises ``IntegrityError``
        (a status with that timestamp already exists), a warning is logged
        instead of raising.

        Parameters
        ----------
        conn : Connection
            Connection with an open transaction, owned by the caller.
        status : Status
            Status to record; its ``name`` is stored.
        timestamp : datetime
            Time of the status. It is converted to UTC with
            ``utcify_timestamp(..., origin="local")`` before storing.
        """
        # Make sure to convert the timestamp to UTC before storing it in the database.
        timestamp = utcify_timestamp(timestamp, origin="local")
        try:
            conn.execute(
                self._schema.trial_status.insert().values(
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                    ts=timestamp,
                    status=status.name,
                )
            )
        except IntegrityError as ex:
            # NOTE(review): on backends that abort the whole transaction after a
            # failed statement (e.g. PostgreSQL), later statements on `conn` may
            # fail after this point. Consider a SAVEPOINT (`conn.begin_nested()`)
            # once SQLite/pysqlite savepoint handling is confirmed to work.
            _LOG.warning(
                "Status with that timestamp already exists: %s %s :: %s",
                self,
                timestamp,
                ex,
            )