Coverage for mlos_bench/mlos_bench/storage/sql/common.py: 100%

40 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Common SQL methods for accessing the stored benchmark data.""" 

6from typing import Dict, Optional 

7 

8import pandas 

9from sqlalchemy import Engine, Integer, and_, func, select 

10 

11from mlos_bench.environments.status import Status 

12from mlos_bench.storage.base_experiment_data import ExperimentData 

13from mlos_bench.storage.base_trial_data import TrialData 

14from mlos_bench.storage.sql.schema import DbSchema 

15from mlos_bench.util import utcify_nullable_timestamp, utcify_timestamp 

16 

17 

def get_trials(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> Dict[int, TrialData]:
    """
    Load the trials of an experiment, keyed by their trial_id.

    Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine used to open the connection for the query.
    schema : DbSchema
        Schema object providing the ``trial`` table.
    experiment_id : str
        Only trials whose ``exp_id`` equals this value are returned.
    tunable_config_id : Optional[int]
        When not None, additionally restrict the trials to those whose
        ``config_id`` equals this value.

    Returns
    -------
    Dict[int, TrialData]
        TrialSqlData objects keyed by trial_id, inserted in ascending
        (exp_id, trial_id) order.
    """
    # pylint: disable=import-outside-toplevel,cyclic-import
    from mlos_bench.storage.sql.trial_data import TrialSqlData

    trial_table = schema.trial

    # Compose the query: mandatory experiment filter first ...
    query = trial_table.select().where(trial_table.c.exp_id == experiment_id)
    # ... then the optional filter on a particular tunable config ...
    if tunable_config_id is not None:
        query = query.where(trial_table.c.config_id == tunable_config_id)
    # ... and finally a stable ordering of the rows.
    query = query.order_by(
        trial_table.c.exp_id.asc(),
        trial_table.c.trial_id.asc(),
    )

    trials_by_id: Dict[int, TrialData] = {}
    with engine.connect() as conn:
        for row in conn.execute(query).fetchall():
            trials_by_id[row.trial_id] = TrialSqlData(
                engine=engine,
                schema=schema,
                experiment_id=experiment_id,
                trial_id=row.trial_id,
                config_id=row.config_id,
                ts_start=utcify_timestamp(row.ts_start, origin="utc"),
                ts_end=utcify_nullable_timestamp(row.ts_end, origin="utc"),
                # The stored status is looked up by name in the Status enum.
                status=Status[row.status],
            )
    return trials_by_id

64 

65 

def get_results_df(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> pandas.DataFrame:
    """
    Gets the trials, their tunable config values, and their results for the given
    experiment as a single wide-format DataFrame, optionally additionally
    restricted by tunable_config_id.

    Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine used to open the connection for the queries.
    schema : DbSchema
        Schema object providing the ``trial``, ``config_param`` and
        ``trial_result`` tables.
    experiment_id : str
        Only data whose ``exp_id`` equals this value is returned.
    tunable_config_id : Optional[int]
        When not None, additionally restrict the data to trials whose
        ``config_id`` equals this value.

    Returns
    -------
    pandas.DataFrame
        The trial metadata columns ``trial_id``, ``ts_start``, ``ts_end``,
        ``tunable_config_id``, ``tunable_config_trial_group_id`` and ``status``,
        left-merged with one column per config parameter (named with the
        ``ExperimentData.CONFIG_COLUMN_PREFIX`` prefix) and one column per
        result metric (named with the ``ExperimentData.RESULT_COLUMN_PREFIX``
        prefix). Trials without config parameters or without results keep
        their row; the corresponding columns are left unfilled by the merge.
    """
    # pylint: disable=too-many-locals
    with engine.connect() as conn:
        # Compose a subquery to fetch the tunable_config_trial_group_id for each tunable config.
        # The group id is the smallest trial_id among the trials sharing a config.
        tunable_config_group_id_stmt = (
            schema.trial.select()
            .with_only_columns(
                schema.trial.c.exp_id,
                schema.trial.c.config_id,
                func.min(schema.trial.c.trial_id)
                .cast(Integer)
                .label("tunable_config_trial_group_id"),
            )
            .where(
                schema.trial.c.exp_id == experiment_id,
            )
            .group_by(
                schema.trial.c.exp_id,
                schema.trial.c.config_id,
            )
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            tunable_config_group_id_stmt = tunable_config_group_id_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery()

        # Get each trial's metadata.
        cur_trials_stmt = (
            select(
                schema.trial,
                tunable_config_trial_group_id_subquery,
            )
            .where(
                schema.trial.c.exp_id == experiment_id,
                and_(
                    tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id,
                    tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id,
                ),
            )
            .order_by(
                schema.trial.c.exp_id.asc(),
                schema.trial.c.trial_id.asc(),
            )
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            cur_trials_stmt = cur_trials_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        cur_trials = conn.execute(cur_trials_stmt)
        trials_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    utcify_timestamp(row.ts_start, origin="utc"),
                    utcify_nullable_timestamp(row.ts_end, origin="utc"),
                    row.config_id,
                    row.tunable_config_trial_group_id,
                    row.status,
                )
                for row in cur_trials.fetchall()
            ],
            columns=[
                "trial_id",
                "ts_start",
                "ts_end",
                "tunable_config_id",
                "tunable_config_trial_group_id",
                "status",
            ],
        )

        # Get each trial's config in wide format.
        configs_stmt = (
            schema.trial.select()
            .with_only_columns(
                schema.trial.c.trial_id,
                schema.trial.c.config_id,
                schema.config_param.c.param_id,
                schema.config_param.c.param_value,
            )
            .where(
                schema.trial.c.exp_id == experiment_id,
            )
            .join(
                schema.config_param,
                schema.config_param.c.config_id == schema.trial.c.config_id,
                isouter=True,
            )
            .order_by(
                schema.trial.c.trial_id,
                schema.config_param.c.param_id,
            )
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            configs_stmt = configs_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        configs = conn.execute(configs_stmt)
        configs_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    row.config_id,
                    ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id,
                    row.param_value,
                )
                for row in configs.fetchall()
                # The outer join yields a NULL param_id for a trial whose config has
                # no rows in config_param. Such a row has no column to pivot into and
                # prefixing None would fail, so skip it here; the trial itself is
                # still kept by the left merge from trials_df below.
                if row.param_id is not None
            ],
            columns=["trial_id", "tunable_config_id", "param", "value"],
        ).pivot(
            index=["trial_id", "tunable_config_id"],
            columns="param",
            values="value",
        )
        # Convert values to numbers where possible; anything that fails to
        # convert is restored from the unconverted frame.
        configs_df = configs_df.apply(  # type: ignore[assignment]  # (fp)
            pandas.to_numeric,
            errors="coerce",
        ).fillna(configs_df)

        # Get each trial's results in wide format.
        results_stmt = (
            schema.trial_result.select()
            .with_only_columns(
                schema.trial_result.c.trial_id,
                schema.trial_result.c.metric_id,
                schema.trial_result.c.metric_value,
            )
            .where(
                schema.trial_result.c.exp_id == experiment_id,
            )
            .order_by(
                schema.trial_result.c.trial_id,
                schema.trial_result.c.metric_id,
            )
        )
        # Optionally restrict to those using a particular tunable config.
        # The config_id lives on the trial table, hence the join.
        if tunable_config_id is not None:
            results_stmt = results_stmt.join(
                schema.trial,
                and_(
                    schema.trial.c.exp_id == schema.trial_result.c.exp_id,
                    schema.trial.c.trial_id == schema.trial_result.c.trial_id,
                    schema.trial.c.config_id == tunable_config_id,
                ),
            )
        results = conn.execute(results_stmt)
        results_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id,
                    row.metric_value,
                )
                for row in results.fetchall()
            ],
            columns=["trial_id", "metric", "value"],
        ).pivot(
            index="trial_id",
            columns="metric",
            values="value",
        )
        # Same numeric conversion as for the config values above.
        results_df = results_df.apply(  # type: ignore[assignment]  # (fp)
            pandas.to_numeric,
            errors="coerce",
        ).fillna(results_df)

        # Concat the trials, configs, and results.
        return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left").merge(
            results_df,
            on="trial_id",
            how="left",
        )