Coverage for mlos_bench/mlos_bench/storage/sql/trial.py: 98%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Saving and updating benchmark data using SQLAlchemy backend.""" 

6 

7import logging 

8from datetime import datetime 

9from typing import Any, Dict, List, Literal, Optional, Tuple 

10 

11from sqlalchemy import Connection, Engine 

12from sqlalchemy.exc import IntegrityError 

13 

14from mlos_bench.environments.status import Status 

15from mlos_bench.storage.base_storage import Storage 

16from mlos_bench.storage.sql.schema import DbSchema 

17from mlos_bench.tunables.tunable_groups import TunableGroups 

18from mlos_bench.util import nullable, utcify_timestamp 

19 

# Module-scoped logger: named after this module so records can be filtered per component.
_LOG = logging.getLogger(name=__name__)

21 

22 

class Trial(Storage.Trial):
    """Persist the status and results of one experiment trial in a SQL database."""

    # Trial statuses that must never be overwritten by a completion update.
    _FINAL_STATUS_NAMES: Tuple[str, ...] = ("SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT")
    # Trial statuses that must never be overwritten by a (re)start update.
    _RUNNING_OR_FINAL_STATUS_NAMES: Tuple[str, ...] = ("RUNNING",) + _FINAL_STATUS_NAMES

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        opt_targets: Dict[str, Literal["min", "max"]],
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Bind a trial to the SQLAlchemy engine and schema used to persist it.

        ``config_id`` is forwarded to the base class as ``tunable_config_id``;
        the remaining non-SQL arguments are forwarded unchanged.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            opt_targets=opt_targets,
            config=config,
        )
        self._engine = engine
        self._schema = schema

    def update(
        self,
        status: Status,
        timestamp: datetime,
        metrics: Optional[Dict[str, Any]] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Record a status change of the trial and, when it completes, its metrics.

        Inside a single transaction this inserts a ``trial_status`` row and then
        updates the ``trial`` row: ``ts_end`` (plus ``trial_result`` rows for the
        metrics) when ``status.is_completed()``, otherwise ``ts_start``.

        Parameters
        ----------
        status : Status
            New status of the trial.
        timestamp : datetime
            Time of the status change; converted to UTC before it is stored.
        metrics : Optional[Dict[str, Any]]
            Final metrics. Must be None unless the status is a completed one.

        Returns
        -------
        Optional[Dict[str, Any]]
            The metrics as returned by the base class ``update()``.

        Raises
        ------
        RuntimeError
            If the completion update reports a row count other than 1 (or -1),
            e.g. because the trial was already finished.
        """
        # The database keeps UTC timestamps only; normalize before anything else uses it.
        utc_ts = utcify_timestamp(timestamp, origin="local")
        metrics = super().update(status, utc_ts, metrics)
        with self._engine.begin() as conn:
            self._update_status(conn, status, utc_ts)
            try:
                if not status.is_completed():
                    # Starting the trial: only status and ts_start change.
                    assert metrics is None, f"Unexpected metrics for status: {status}"
                    n_rows = self._update_trial_row(
                        conn,
                        self._RUNNING_OR_FINAL_STATUS_NAMES,
                        status=status.name,
                        ts_start=utc_ts,
                    )
                    # -1 is accepted as "row count unknown" (PEP 249 rowcount convention).
                    if n_rows not in (1, -1):
                        # A trial that is already running or finished keeps its old
                        # status and ts_start; just note the ignored transition.
                        _LOG.warning("Trial %s :: cannot be updated to: %s", self, status)
                else:
                    # Finishing the trial: set the final status and ts_end, but only
                    # if the row has not been finalized before.
                    n_rows = self._update_trial_row(
                        conn,
                        self._FINAL_STATUS_NAMES,
                        status=status.name,
                        ts_end=utc_ts,
                    )
                    if n_rows not in (1, -1):
                        _LOG.warning("Trial %s :: update failed: %s", self, status)
                        raise RuntimeError(
                            f"Failed to update the status of the trial {self} to {status}. "
                            f"({n_rows} rows)"
                        )
                    if metrics:
                        result_rows = [
                            {
                                "exp_id": self._experiment_id,
                                "trial_id": self._trial_id,
                                "metric_id": metric_id,
                                "metric_value": nullable(str, metric_value),
                            }
                            for metric_id, metric_value in metrics.items()
                        ]
                        conn.execute(self._schema.trial_result.insert().values(result_rows))
            except Exception:
                conn.rollback()
                raise
        return metrics

    def update_telemetry(
        self,
        status: Status,
        timestamp: datetime,
        metrics: List[Tuple[datetime, str, Any]],
    ) -> None:
        """
        Record the current status plus a batch of ``(timestamp, metric, value)`` telemetry points.

        The status row is written in its own transaction. Each telemetry point is then
        written in a separate transaction, so a point that already exists (IntegrityError)
        is logged and skipped rather than aborting the rest of the batch.

        Parameters
        ----------
        status : Status
            Current status of the trial.
        timestamp : datetime
            Time of the status report; converted to UTC before it is stored.
        metrics : List[Tuple[datetime, str, Any]]
            Telemetry points; each point's timestamp is converted to UTC as well.
        """
        super().update_telemetry(status, timestamp, metrics)
        # The database keeps UTC timestamps only.
        utc_ts = utcify_timestamp(timestamp, origin="local")
        utc_points = [(utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics]
        # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()`
        # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of
        # a bulk upsert.
        # See Also: comments in <https://github.com/microsoft/MLOS/pull/466>
        with self._engine.begin() as conn:
            self._update_status(conn, status, utc_ts)
        telemetry_tbl = self._schema.trial_telemetry
        for point_ts, point_key, point_val in utc_points:
            with self._engine.begin() as conn:
                try:
                    conn.execute(
                        telemetry_tbl.insert().values(
                            exp_id=self._experiment_id,
                            trial_id=self._trial_id,
                            ts=point_ts,
                            metric_id=point_key,
                            metric_value=nullable(str, point_val),
                        )
                    )
                except IntegrityError as ex:
                    _LOG.warning(
                        "Record already exists: %s :: %s", (point_ts, point_key, point_val), ex
                    )

    def _update_trial_row(
        self,
        conn: Connection,
        locked_statuses: Tuple[str, ...],
        **new_values: Any,
    ) -> int:
        """
        Update this trial's row in the ``trial`` table unless it is already past a given state.

        Only a row with ``ts_end IS NULL`` whose status is not in ``locked_statuses``
        is modified. Returns the row count reported by the database driver.
        """
        trial_tbl = self._schema.trial
        stmt = (
            trial_tbl.update()
            .where(
                trial_tbl.c.exp_id == self._experiment_id,
                trial_tbl.c.trial_id == self._trial_id,
                trial_tbl.c.ts_end.is_(None),
                trial_tbl.c.status.notin_(list(locked_statuses)),
            )
            .values(**new_values)
        )
        return conn.execute(stmt).rowcount

    def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None:
        """
        Insert a new status record into the ``trial_status`` table.

        This call is idempotent: if a record for this trial with the same
        timestamp already exists, the resulting IntegrityError is logged
        and otherwise ignored.
        """
        # The database keeps UTC timestamps only.
        utc_ts = utcify_timestamp(timestamp, origin="local")
        insert_stmt = self._schema.trial_status.insert().values(
            exp_id=self._experiment_id,
            trial_id=self._trial_id,
            ts=utc_ts,
            status=status.name,
        )
        try:
            conn.execute(insert_stmt)
        except IntegrityError as ex:
            _LOG.warning(
                "Status with that timestamp already exists: %s %s :: %s",
                self,
                utc_ts,
                ex,
            )