Coverage for mlos_bench/mlos_bench/storage/sql/common.py: 100%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Common SQL methods for accessing the stored benchmark data.""" 

6 

7from collections.abc import Mapping 

8from typing import Any 

9 

10import pandas 

11from sqlalchemy import Integer, and_, func, select 

12from sqlalchemy.engine import Connection, Engine 

13from sqlalchemy.schema import Table 

14 

15from mlos_bench.environments.status import Status 

16from mlos_bench.storage.base_experiment_data import ExperimentData 

17from mlos_bench.storage.base_trial_data import TrialData 

18from mlos_bench.storage.sql.schema import DbSchema 

19from mlos_bench.util import nullable, utcify_nullable_timestamp, utcify_timestamp 

20 

21 

def save_params(
    conn: Connection,
    table: Table,
    params: Mapping[str, Any],
    **kwargs: Any,
) -> None:
    """
    Insert one row per (param_id, param_value) pair into the given Table.

    Parameters
    ----------
    conn : sqlalchemy.engine.Connection
        A connection to the backend database.
    table : sqlalchemy.schema.Table
        The table to insert the rows into.
    params : Mapping[str, Any]
        The (param_id, param_value) pairs to store. Each value is passed through
        ``nullable(str, value)`` before being written.
    **kwargs : dict[str, Any]
        Extra column values (e.g., primary key info for the given table) that are
        copied into every inserted row.

    Notes
    -----
    Nothing is executed when ``params`` is empty. The rows are written with a
    plain ``table.insert()`` in a single ``conn.execute()`` call.
    """
    if not params:
        return

    rows = []
    for param_id, param_value in params.items():
        # Start from the shared key columns, then add the per-parameter columns.
        row = dict(kwargs)
        row["param_id"] = param_id
        row["param_value"] = nullable(str, param_value)
        rows.append(row)

    conn.execute(table.insert(), rows)

51 

52 

def get_trials(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: int | None = None,
) -> dict[int, TrialData]:
    """
    Fetch the trials of an experiment as :py:class:`~.TrialData` objects, optionally
    restricted to the trials that used a particular ``tunable_config_id``.

    Parameters
    ----------
    engine : sqlalchemy.engine.Engine
        The engine used to open a connection to the backend database.
    schema : DbSchema
        The schema object providing the ``trial`` table.
    experiment_id : str
        Only trials whose ``exp_id`` equals this value are returned.
    tunable_config_id : int | None
        When not None, only trials whose ``config_id`` equals this value are returned.

    Returns
    -------
    trials : dict[int, TrialData]
        The matching trials keyed by their ``trial_id``, queried in ascending
        ``trial_id`` order.

    See Also
    --------
    :py:class:`~mlos_bench.storage.sql.tunable_config_trial_group_data.TunableConfigTrialGroupSqlData`
    :py:class:`~mlos_bench.storage.sql.experiment_data.ExperimentSqlData`
    """  # pylint: disable=line-too-long # noqa: E501
    # pylint: disable=import-outside-toplevel,cyclic-import
    from mlos_bench.storage.sql.trial_data import TrialSqlData

    with engine.connect() as conn:
        trial_cols = schema.trial.c

        # Build up a SQL statement for fetching the experiment's trials.
        query = (
            schema.trial.select()
            .where(trial_cols.exp_id == experiment_id)
            .order_by(trial_cols.exp_id.asc(), trial_cols.trial_id.asc())
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            query = query.where(trial_cols.config_id == tunable_config_id)

        trials_by_id: dict[int, TrialData] = {}
        for row in conn.execute(query).fetchall():
            trials_by_id[row.trial_id] = TrialSqlData(
                engine=engine,
                schema=schema,
                experiment_id=experiment_id,
                trial_id=row.trial_id,
                config_id=row.config_id,
                ts_start=utcify_timestamp(row.ts_start, origin="utc"),
                ts_end=utcify_nullable_timestamp(row.ts_end, origin="utc"),
                # The status column holds the name of a Status member.
                status=Status[row.status],
                trial_runner_id=row.trial_runner_id,
            )
        return trials_by_id

103 

104 

def get_results_df(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: int | None = None,
) -> pandas.DataFrame:
    """
    Build a wide-format DataFrame of the trials for the given ``experiment_id``,
    optionally restricted to the trials that used a particular ``tunable_config_id``.

    Each row describes one trial: its metadata (``trial_id``, ``ts_start``,
    ``ts_end``, ``tunable_config_id``, ``tunable_config_trial_group_id``, ``status``,
    ``trial_runner_id``), followed by one column per config parameter prefixed with
    :py:attr:`.ExperimentData.CONFIG_COLUMN_PREFIX` and one column per result metric
    prefixed with :py:attr:`.ExperimentData.RESULT_COLUMN_PREFIX`.

    Parameters
    ----------
    engine : sqlalchemy.engine.Engine
        The engine used to open a connection to the backend database.
    schema : DbSchema
        The schema object providing the ``trial``, ``config_param`` and
        ``trial_result`` tables.
    experiment_id : str
        Only data whose ``exp_id`` equals this value is included.
    tunable_config_id : int | None
        When not None, only trials whose ``config_id`` equals this value are included.

    Returns
    -------
    results : pandas.DataFrame
        The trial metadata left-merged with the pivoted configs and results, so a
        trial keeps its row even when it has no config params or results stored.

    Notes
    -----
    ``tunable_config_trial_group_id`` is the smallest ``trial_id`` among the
    experiment's trials that share the same ``config_id``.

    See Also
    --------
    :py:class:`~mlos_bench.storage.sql.tunable_config_trial_group_data.TunableConfigTrialGroupSqlData`
    :py:class:`~mlos_bench.storage.sql.experiment_data.ExperimentSqlData`
    """  # pylint: disable=line-too-long # noqa: E501
    # pylint: disable=too-many-locals

    def _numeric_where_possible(frame: pandas.DataFrame) -> pandas.DataFrame:
        # Values that parse as numbers are converted; wherever the coercion
        # produced a missing value, the original value is restored.
        return frame.apply(pandas.to_numeric, errors="coerce").fillna(frame)

    with engine.connect() as conn:
        trial_cols = schema.trial.c

        # Subquery computing the tunable_config_trial_group_id of each tunable config.
        group_id_stmt = (
            schema.trial.select()
            .with_only_columns(
                trial_cols.exp_id,
                trial_cols.config_id,
                func.min(trial_cols.trial_id)
                .cast(Integer)
                .label("tunable_config_trial_group_id"),
            )
            .where(trial_cols.exp_id == experiment_id)
            .group_by(trial_cols.exp_id, trial_cols.config_id)
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            group_id_stmt = group_id_stmt.where(trial_cols.config_id == tunable_config_id)
        group_ids = group_id_stmt.subquery()

        # Fetch each trial's metadata together with its config's group id.
        trials_stmt = (
            select(schema.trial, group_ids)
            .where(
                trial_cols.exp_id == experiment_id,
                and_(
                    group_ids.c.exp_id == trial_cols.exp_id,
                    group_ids.c.config_id == trial_cols.config_id,
                ),
            )
            .order_by(trial_cols.exp_id.asc(), trial_cols.trial_id.asc())
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            trials_stmt = trials_stmt.where(trial_cols.config_id == tunable_config_id)
        trial_records = [
            (
                rec.trial_id,
                utcify_timestamp(rec.ts_start, origin="utc"),
                utcify_nullable_timestamp(rec.ts_end, origin="utc"),
                rec.config_id,
                rec.tunable_config_trial_group_id,
                rec.status,
                rec.trial_runner_id,
            )
            for rec in conn.execute(trials_stmt).fetchall()
        ]
        trials_df = pandas.DataFrame(
            trial_records,
            columns=[
                "trial_id",
                "ts_start",
                "ts_end",
                "tunable_config_id",
                "tunable_config_trial_group_id",
                "status",
                "trial_runner_id",
            ],
        )

        # Fetch each trial's config params (long format) and pivot them to wide format.
        param_cols = schema.config_param.c
        configs_stmt = (
            schema.trial.select()
            .with_only_columns(
                trial_cols.trial_id,
                trial_cols.config_id,
                param_cols.param_id,
                param_cols.param_value,
            )
            .where(trial_cols.exp_id == experiment_id)
            .join(schema.config_param, param_cols.config_id == trial_cols.config_id)
            .order_by(trial_cols.trial_id, param_cols.param_id)
        )
        if tunable_config_id is not None:
            configs_stmt = configs_stmt.where(trial_cols.config_id == tunable_config_id)
        config_records = [
            (
                rec.trial_id,
                rec.config_id,
                ExperimentData.CONFIG_COLUMN_PREFIX + rec.param_id,
                rec.param_value,
            )
            for rec in conn.execute(configs_stmt).fetchall()
        ]
        configs_long = pandas.DataFrame(
            config_records,
            columns=["trial_id", "tunable_config_id", "param", "value"],
        )
        configs_df = _numeric_where_possible(
            configs_long.pivot(
                index=["trial_id", "tunable_config_id"],
                columns="param",
                values="value",
            )
        )

        # Fetch each trial's results (long format) and pivot them to wide format.
        result_cols = schema.trial_result.c
        results_stmt = (
            schema.trial_result.select()
            .with_only_columns(
                result_cols.trial_id,
                result_cols.metric_id,
                result_cols.metric_value,
            )
            .where(result_cols.exp_id == experiment_id)
            .order_by(result_cols.trial_id, result_cols.metric_id)
        )
        if tunable_config_id is not None:
            # The results table is filtered by config through a join with the trial table.
            results_stmt = results_stmt.join(
                schema.trial,
                and_(
                    trial_cols.exp_id == result_cols.exp_id,
                    trial_cols.trial_id == result_cols.trial_id,
                    trial_cols.config_id == tunable_config_id,
                ),
            )
        result_records = [
            (
                rec.trial_id,
                ExperimentData.RESULT_COLUMN_PREFIX + rec.metric_id,
                rec.metric_value,
            )
            for rec in conn.execute(results_stmt).fetchall()
        ]
        results_long = pandas.DataFrame(
            result_records,
            columns=["trial_id", "metric", "value"],
        )
        results_df = _numeric_where_possible(
            results_long.pivot(
                index="trial_id",
                columns="metric",
                values="value",
            )
        )

        # Combine the trials, configs, and results; left merges keep every trial row.
        with_configs = trials_df.merge(
            configs_df,
            on=["trial_id", "tunable_config_id"],
            how="left",
        )
        return with_configs.merge(
            results_df,
            on="trial_id",
            how="left",
        )