Coverage for mlos_bench/mlos_bench/storage/sql/experiment.py: 89%

100 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Saving and restoring the benchmark data using SQLAlchemy.""" 

6 

7import hashlib 

8import logging 

9from datetime import datetime 

10from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple 

11 

12from pytz import UTC 

13from sqlalchemy import Connection, CursorResult, Engine, Table, column, func, select 

14 

15from mlos_bench.environments.status import Status 

16from mlos_bench.storage.base_storage import Storage 

17from mlos_bench.storage.sql.schema import DbSchema 

18from mlos_bench.storage.sql.trial import Trial 

19from mlos_bench.tunables.tunable_groups import TunableGroups 

20from mlos_bench.util import nullable, utcify_timestamp 

21 

# Module-level logger, named after this module's import path.
_LOG = logging.getLogger(__name__)

23 

24 

class Experiment(Storage.Experiment):
    """
    Logic for retrieving and storing the results of a single experiment.

    All reads and writes go through the SQLAlchemy ``Engine`` and the
    ``DbSchema`` table definitions passed to the constructor.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        tunables: TunableGroups,
        experiment_id: str,
        trial_id: int,
        root_env_config: str,
        description: str,
        opt_targets: Dict[str, Literal["min", "max"]],
    ):
        """
        Create the experiment object; no database access happens here.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections and transactions.
        schema : DbSchema
            Table definitions used to build the queries.
        tunables : TunableGroups
            Forwarded unchanged to the base class.
        experiment_id : str
            Forwarded unchanged to the base class.
        trial_id : int
            Forwarded unchanged to the base class. `_setup()` replaces the
            stored value when the experiment already has trials in the database.
        root_env_config : str
            Forwarded unchanged to the base class.
        description : str
            Forwarded unchanged to the base class.
        opt_targets : Dict[str, Literal["min", "max"]]
            Mapping of optimization target name to direction; forwarded
            unchanged to the base class.
        """
        super().__init__(
            tunables=tunables,
            experiment_id=experiment_id,
            trial_id=trial_id,
            root_env_config=root_env_config,
            description=description,
            opt_targets=opt_targets,
        )
        self._engine = engine
        self._schema = schema

    def _setup(self) -> None:
        """
        Register a new experiment in the database or resume an existing one.

        If no row exists for this experiment ID, insert the experiment record
        and one `objectives` row per optimization target. Otherwise, continue
        numbering trials after the largest stored trial ID (if any) and log a
        warning when the stored git commit differs from the current one.
        """
        # NOTE(review): `self._git_repo` / `self._git_commit` are presumably
        # populated by the base class -- confirm against `Storage.Experiment`.
        super()._setup()
        with self._engine.begin() as conn:
            # Get git info and the last trial ID for the experiment.
            # The outer join keeps the experiment row even when it has no trials,
            # in which case the aggregated `trial_id` is NULL (None).
            # pylint: disable=not-callable
            exp_info = conn.execute(
                self._schema.experiment.select()
                .with_only_columns(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                    func.max(self._schema.trial.c.trial_id).label("trial_id"),
                )
                .join(
                    self._schema.trial,
                    self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id,
                    isouter=True,
                )
                .where(
                    self._schema.experiment.c.exp_id == self._experiment_id,
                )
                .group_by(
                    self._schema.experiment.c.git_repo,
                    self._schema.experiment.c.git_commit,
                    self._schema.experiment.c.root_env_config,
                )
            ).fetchone()
            if exp_info is None:
                _LOG.info("Start new experiment: %s", self._experiment_id)
                # It's a new experiment: create a record for it in the database.
                conn.execute(
                    self._schema.experiment.insert().values(
                        exp_id=self._experiment_id,
                        description=self._description,
                        git_repo=self._git_repo,
                        git_commit=self._git_commit,
                        root_env_config=self._root_env_config,
                    )
                )
                objectives = [
                    {
                        "exp_id": self._experiment_id,
                        "optimization_target": opt_target,
                        "optimization_direction": opt_dir,
                    }
                    for (opt_target, opt_dir) in self.opt_targets.items()
                ]
                if objectives:
                    conn.execute(self._schema.objectives.insert().values(objectives))
                else:
                    # An empty list is not a multi-row insert: it would compile
                    # to an INSERT without any column values. Skip it instead
                    # (same policy as `_save_params` for empty input).
                    _LOG.warning(
                        "Experiment %s has no optimization targets",
                        self._experiment_id,
                    )
            else:
                if exp_info.trial_id is not None:
                    self._trial_id = exp_info.trial_id + 1
                _LOG.info(
                    "Continue experiment: %s last trial: %s resume from: %d",
                    self._experiment_id,
                    exp_info.trial_id,
                    self._trial_id,
                )
                # TODO: Sanity check that certain critical configs (e.g.,
                # objectives) haven't changed to be incompatible such that a new
                # experiment should be started (possibly by prewarming with the
                # previous one).
                if exp_info.git_commit != self._git_commit:
                    _LOG.warning(
                        "Experiment %s git expected: %s %s",
                        self,
                        exp_info.git_repo,
                        exp_info.git_commit,
                    )

    def merge(self, experiment_ids: List[str]) -> None:
        """
        Merge other experiments into this one (not implemented yet).

        Parameters
        ----------
        experiment_ids : List[str]
            IDs of the experiments to merge into the current one.

        Raises
        ------
        NotImplementedError
            Always.
        """
        _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids)
        raise NotImplementedError("TODO")

    def load_tunable_config(self, config_id: int) -> Dict[str, Any]:
        """
        Load the stored parameters of a tunable configuration.

        Parameters
        ----------
        config_id : int
            ID of the configuration in the `config_param` table.

        Returns
        -------
        Dict[str, Any]
            Mapping of `param_id` to `param_value`.
        """
        with self._engine.connect() as conn:
            return self._get_key_val(conn, self._schema.config_param, "param", config_id=config_id)

    def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]:
        """
        Load the telemetry records of one trial of this experiment.

        Parameters
        ----------
        trial_id : int
            ID of the trial within the current experiment.

        Returns
        -------
        List[Tuple[datetime, str, Any]]
            (timestamp, metric_id, metric_value) tuples ordered by timestamp
            and then by metric ID.
        """
        with self._engine.connect() as conn:
            cur_telemetry = conn.execute(
                self._schema.trial_telemetry.select()
                .where(
                    self._schema.trial_telemetry.c.exp_id == self._experiment_id,
                    self._schema.trial_telemetry.c.trial_id == trial_id,
                )
                .order_by(
                    self._schema.trial_telemetry.c.ts,
                    self._schema.trial_telemetry.c.metric_id,
                )
            )
            # Not all storage backends store the original zone info.
            # We try to ensure data is entered in UTC and augment it on return again here.
            return [
                (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
                for row in cur_telemetry.fetchall()
            ]

    def load(
        self,
        last_trial_id: int = -1,
    ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]:
        """
        Load the finished trials of this experiment.

        Only trials with status SUCCEEDED, FAILED, or TIMED_OUT are returned.

        Parameters
        ----------
        last_trial_id : int
            Return only trials whose ID is strictly greater than this value.
            The default of -1 does not exclude any non-negative trial ID.

        Returns
        -------
        (trial_ids, configs, scores, status) : Tuple
            Four parallel lists ordered by ascending trial ID: the trial IDs,
            the tunable parameters of each trial, the result metrics (None for
            trials whose status is not a success), and the trial statuses.
        """
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.trial_id,
                    self._schema.trial.c.config_id,
                    self._schema.trial.c.status,
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    self._schema.trial.c.trial_id > last_trial_id,
                    self._schema.trial.c.status.in_(["SUCCEEDED", "FAILED", "TIMED_OUT"]),
                )
                .order_by(
                    self._schema.trial.c.trial_id.asc(),
                )
            )

            trial_ids: List[int] = []
            configs: List[Dict[str, Any]] = []
            scores: List[Optional[Dict[str, Any]]] = []
            status: List[Status] = []

            # Rows are fetched up front because the loop issues further
            # queries on the same connection.
            for trial in cur_trials.fetchall():
                stat = Status[trial.status]
                status.append(stat)
                trial_ids.append(trial.trial_id)
                configs.append(
                    self._get_key_val(
                        conn,
                        self._schema.config_param,
                        "param",
                        config_id=trial.config_id,
                    )
                )
                if stat.is_succeeded():
                    scores.append(
                        self._get_key_val(
                            conn,
                            self._schema.trial_result,
                            "metric",
                            exp_id=self._experiment_id,
                            trial_id=trial.trial_id,
                        )
                    )
                else:
                    scores.append(None)

            return (trial_ids, configs, scores, status)

    @staticmethod
    def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> Dict[str, Any]:
        """
        Helper method to retrieve key-value pairs from the database.

        (E.g., configurations, results, and telemetry).

        Parameters
        ----------
        conn : Connection
            Open database connection to run the query on.
        table : Table
            Table to select from.
        field : str
            Column name prefix: the `{field}_id` column supplies the keys and
            the `{field}_value` column supplies the values.
        kwargs : Any
            Column name / value pairs; each one becomes an equality filter.

        Returns
        -------
        Dict[str, Any]
            Mapping of `{field}_id` to `{field}_value` for the matching rows.
        """
        cur_result: CursorResult[Tuple[str, Any]] = conn.execute(
            select(
                column(f"{field}_id"),
                column(f"{field}_value"),
            )
            .select_from(table)
            .where(*[column(key) == val for (key, val) in kwargs.items()])
        )
        # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to
        # avoid naming conflicts.
        return dict(
            row._tuple() for row in cur_result.fetchall()  # pylint: disable=protected-access
        )

    @staticmethod
    def _save_params(
        conn: Connection,
        table: Table,
        params: Dict[str, Any],
        **kwargs: Any,
    ) -> None:
        """
        Insert key-value parameters into the given table, one row per entry.

        Does nothing when `params` is empty.

        Parameters
        ----------
        conn : Connection
            Open database connection to run the insert on.
        table : Table
            Table with `param_id` and `param_value` columns.
        params : Dict[str, Any]
            Parameters to store; values are passed through `nullable(str, val)`
            before being written.
        kwargs : Any
            Extra column values added to every inserted row
            (e.g., `config_id` or `exp_id` / `trial_id`).
        """
        if not params:
            return
        conn.execute(
            table.insert(),
            [
                {**kwargs, "param_id": key, "param_value": nullable(str, val)}
                for (key, val) in params.items()
            ],
        )

    def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]:
        """
        Iterate over the trials of this experiment that are still pending.

        A trial is selected when it has no end timestamp, its start timestamp
        is unset or not later than `timestamp`, and its status is PENDING
        (or also READY / RUNNING when `running` is True).

        The database connection stays open until the generator is exhausted
        or closed.

        Parameters
        ----------
        timestamp : datetime
            Cut-off for the trial start time; normalized with
            `utcify_timestamp(..., origin="local")` before use.
        running : bool
            If True, also include trials with status READY or RUNNING.

        Yields
        ------
        Storage.Trial
            One trial object per matching row.
        """
        timestamp = utcify_timestamp(timestamp, origin="local")
        _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp)
        if running:
            pending_status = ["PENDING", "READY", "RUNNING"]
        else:
            pending_status = ["PENDING"]
        with self._engine.connect() as conn:
            cur_trials = conn.execute(
                self._schema.trial.select().where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                    (
                        self._schema.trial.c.ts_start.is_(None)
                        | (self._schema.trial.c.ts_start <= timestamp)
                    ),
                    self._schema.trial.c.ts_end.is_(None),
                    self._schema.trial.c.status.in_(pending_status),
                )
            )
            for trial in cur_trials.fetchall():
                tunables = self._get_key_val(
                    conn,
                    self._schema.config_param,
                    "param",
                    config_id=trial.config_id,
                )
                config = self._get_key_val(
                    conn,
                    self._schema.trial_param,
                    "param",
                    exp_id=self._experiment_id,
                    trial_id=trial.trial_id,
                )
                yield Trial(
                    engine=self._engine,
                    schema=self._schema,
                    # Reset .is_updated flag after the assignment:
                    tunables=self._tunables.copy().assign(tunables).reset(),
                    experiment_id=self._experiment_id,
                    trial_id=trial.trial_id,
                    config_id=trial.config_id,
                    opt_targets=self._opt_targets,
                    config=config,
                )

    def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int:
        """
        Get the config ID for the given tunables.

        If the config does not exist, create a new record for it.

        Parameters
        ----------
        conn : Connection
            Open database connection (the caller owns the transaction).
        tunables : TunableGroups
            Tunables to look up; identified by the SHA-256 hash of `str(tunables)`.

        Returns
        -------
        int
            ID of the existing or newly created `config` row.
        """
        config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest()
        cur_config = conn.execute(
            self._schema.config.select().where(self._schema.config.c.config_hash == config_hash)
        ).fetchone()
        if cur_config is not None:
            return int(cur_config.config_id)  # mypy doesn't know it's always int
        # Config not found, create a new one:
        config_id: int = conn.execute(
            self._schema.config.insert().values(config_hash=config_hash)
        ).inserted_primary_key[0]
        self._save_params(
            conn,
            self._schema.config_param,
            {tunable.name: tunable.value for (tunable, _group) in tunables},
            config_id=config_id,
        )
        return config_id

    def _new_trial(
        self,
        tunables: TunableGroups,
        ts_start: Optional[datetime] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> Storage.Trial:
        """
        Store a new PENDING trial in the database and return its object.

        The trial row, its tunable configuration (if new), and the optional
        trial parameters are written in a single transaction.

        Parameters
        ----------
        tunables : TunableGroups
            Tunable parameters of the trial.
        ts_start : Optional[datetime]
            Start time of the trial; the current UTC time when omitted.
            Microseconds are always dropped.
        config : Optional[Dict[str, Any]]
            Framework-level trial parameters (not the tunables); stored in the
            `trial_param` table when given.

        Returns
        -------
        Storage.Trial
            Object representing the newly created trial.
        """
        # MySQL can round microseconds into the future causing scheduler to skip trials.
        # Truncate microseconds to avoid this issue.
        ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local").replace(
            microsecond=0
        )
        _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start)
        # `Engine.begin()` commits when the block exits normally and rolls the
        # transaction back when it exits with an exception, so no explicit
        # rollback handler is needed here.
        with self._engine.begin() as conn:
            config_id = self._get_config_id(conn, tunables)
            conn.execute(
                self._schema.trial.insert().values(
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                    config_id=config_id,
                    ts_start=ts_start,
                    status="PENDING",
                )
            )

            # Note: config here is the framework config, not the target
            # environment config (i.e., tunables).
            if config is not None:
                self._save_params(
                    conn,
                    self._schema.trial_param,
                    config,
                    exp_id=self._experiment_id,
                    trial_id=self._trial_id,
                )

            trial = Trial(
                engine=self._engine,
                schema=self._schema,
                tunables=tunables,
                experiment_id=self._experiment_id,
                trial_id=self._trial_id,
                config_id=config_id,
                opt_targets=self._opt_targets,
                config=config,
            )
        # Advance the counter only after the transaction has committed, so a
        # failed commit does not burn a trial ID that was never stored.
        self._trial_id += 1
        return trial

362 conn.rollback() 

363 raise