Coverage for mlos_bench/mlos_bench/storage/sql/common.py: 100%
40 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Common SQL methods for accessing the stored benchmark data."""
6from typing import Dict, Optional
8import pandas
9from sqlalchemy import Engine, Integer, and_, func, select
11from mlos_bench.environments.status import Status
12from mlos_bench.storage.base_experiment_data import ExperimentData
13from mlos_bench.storage.base_trial_data import TrialData
14from mlos_bench.storage.sql.schema import DbSchema
15from mlos_bench.util import utcify_nullable_timestamp, utcify_timestamp
def get_trials(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> Dict[int, TrialData]:
    """
    Fetch the trials of an experiment as TrialData objects keyed by trial_id.

    Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine used to open the connection; also handed to each
        TrialSqlData that is created.
    schema : DbSchema
        Schema object providing the ``trial`` table.
    experiment_id : str
        Only trials whose ``exp_id`` equals this value are returned.
    tunable_config_id : Optional[int]
        When not None, additionally keep only trials whose ``config_id``
        equals this value.

    Returns
    -------
    Dict[int, TrialData]
        Mapping of trial_id to a TrialSqlData built from the matching row.
        Rows are fetched ordered by (exp_id, trial_id) ascending.
    """
    # pylint: disable=import-outside-toplevel,cyclic-import
    from mlos_bench.storage.sql.trial_data import TrialSqlData

    trial_table = schema.trial
    # Build the SQL statement for fetching the experiment's trials.
    query = trial_table.select().where(trial_table.c.exp_id == experiment_id)
    # Optionally narrow it down to a single tunable config.
    if tunable_config_id is not None:
        query = query.where(trial_table.c.config_id == tunable_config_id)
    query = query.order_by(
        trial_table.c.exp_id.asc(),
        trial_table.c.trial_id.asc(),
    )

    trials_by_id: Dict[int, TrialData] = {}
    with engine.connect() as conn:
        for row in conn.execute(query).fetchall():
            trials_by_id[row.trial_id] = TrialSqlData(
                engine=engine,
                schema=schema,
                experiment_id=experiment_id,
                trial_id=row.trial_id,
                config_id=row.config_id,
                # Stored timestamps are passed through the utcify helpers
                # with origin="utc"; ts_end uses the nullable variant.
                ts_start=utcify_timestamp(row.ts_start, origin="utc"),
                ts_end=utcify_nullable_timestamp(row.ts_end, origin="utc"),
                # The stored status string is looked up in Status by name.
                status=Status[row.status],
            )
    return trials_by_id
def get_results_df(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: Optional[int] = None,
) -> pandas.DataFrame:
    """
    Get the trials, configs, and results of an experiment as one wide DataFrame,
    optionally restricted by tunable_config_id.

    Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData.

    Parameters
    ----------
    engine : Engine
        SQLAlchemy engine used to open the connection.
    schema : DbSchema
        Schema object providing the ``trial``, ``config_param`` and
        ``trial_result`` tables.
    experiment_id : str
        Only data whose ``exp_id`` equals this value is returned.
    tunable_config_id : Optional[int]
        When not None, additionally keep only trials whose ``config_id``
        equals this value.

    Returns
    -------
    pandas.DataFrame
        A row for each trial returned by the trial query, with the columns
        ``trial_id``, ``ts_start``, ``ts_end``, ``tunable_config_id``,
        ``tunable_config_trial_group_id`` and ``status``, left-merged with one
        column per config parameter (named
        ``ExperimentData.CONFIG_COLUMN_PREFIX + param_id``) and one column per
        metric (named ``ExperimentData.RESULT_COLUMN_PREFIX + metric_id``).
        Trials without config parameters or without results are kept; their
        corresponding columns are left empty by the left merges.
    """
    # pylint: disable=too-many-locals
    with engine.connect() as conn:
        # Compose a subquery to fetch the tunable_config_trial_group_id for each
        # tunable config: the smallest trial_id that used that config.
        tunable_config_group_id_stmt = (
            schema.trial.select()
            .with_only_columns(
                schema.trial.c.exp_id,
                schema.trial.c.config_id,
                func.min(schema.trial.c.trial_id)
                .cast(Integer)
                .label("tunable_config_trial_group_id"),
            )
            .where(
                schema.trial.c.exp_id == experiment_id,
            )
            .group_by(
                schema.trial.c.exp_id,
                schema.trial.c.config_id,
            )
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            tunable_config_group_id_stmt = tunable_config_group_id_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery()

        # Get each trial's metadata, joined to its config's trial group id.
        cur_trials_stmt = (
            select(
                schema.trial,
                tunable_config_trial_group_id_subquery,
            )
            .where(
                schema.trial.c.exp_id == experiment_id,
                and_(
                    tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id,
                    tunable_config_trial_group_id_subquery.c.config_id
                    == schema.trial.c.config_id,
                ),
            )
            .order_by(
                schema.trial.c.exp_id.asc(),
                schema.trial.c.trial_id.asc(),
            )
        )
        # Optionally restrict to those using a particular tunable config.
        if tunable_config_id is not None:
            cur_trials_stmt = cur_trials_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        cur_trials = conn.execute(cur_trials_stmt)
        trials_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    utcify_timestamp(row.ts_start, origin="utc"),
                    utcify_nullable_timestamp(row.ts_end, origin="utc"),
                    row.config_id,
                    row.tunable_config_trial_group_id,
                    row.status,
                )
                for row in cur_trials.fetchall()
            ],
            columns=[
                "trial_id",
                "ts_start",
                "ts_end",
                "tunable_config_id",
                "tunable_config_trial_group_id",
                "status",
            ],
        )

        # Get each trial's config in wide format.
        configs_stmt = (
            schema.trial.select()
            .with_only_columns(
                schema.trial.c.trial_id,
                schema.trial.c.config_id,
                schema.config_param.c.param_id,
                schema.config_param.c.param_value,
            )
            .where(
                schema.trial.c.exp_id == experiment_id,
            )
            .join(
                schema.config_param,
                schema.config_param.c.config_id == schema.trial.c.config_id,
                isouter=True,
            )
            .order_by(
                schema.trial.c.trial_id,
                schema.config_param.c.param_id,
            )
        )
        if tunable_config_id is not None:
            configs_stmt = configs_stmt.where(
                schema.trial.c.config_id == tunable_config_id,
            )
        configs = conn.execute(configs_stmt)
        configs_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    row.config_id,
                    ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id,
                    row.param_value,
                )
                for row in configs.fetchall()
                # The outer join yields a row with a NULL param_id for a trial
                # whose config has no parameters. Skip it: prefixing None would
                # raise TypeError, and the left merge below keeps the trial.
                if row.param_id is not None
            ],
            columns=["trial_id", "tunable_config_id", "param", "value"],
        ).pivot(
            index=["trial_id", "tunable_config_id"],
            columns="param",
            values="value",
        )
        # Convert values to numbers where possible; values that cannot be
        # parsed are coerced to NaN and then restored from the original frame.
        configs_df = configs_df.apply(  # type: ignore[assignment]  # (fp)
            pandas.to_numeric,
            errors="coerce",
        ).fillna(configs_df)

        # Get each trial's results in wide format.
        results_stmt = (
            schema.trial_result.select()
            .with_only_columns(
                schema.trial_result.c.trial_id,
                schema.trial_result.c.metric_id,
                schema.trial_result.c.metric_value,
            )
            .where(
                schema.trial_result.c.exp_id == experiment_id,
            )
            .order_by(
                schema.trial_result.c.trial_id,
                schema.trial_result.c.metric_id,
            )
        )
        if tunable_config_id is not None:
            # trial_result has no config_id filter of its own here, so join
            # back to the trial table to restrict by config.
            results_stmt = results_stmt.join(
                schema.trial,
                and_(
                    schema.trial.c.exp_id == schema.trial_result.c.exp_id,
                    schema.trial.c.trial_id == schema.trial_result.c.trial_id,
                    schema.trial.c.config_id == tunable_config_id,
                ),
            )
        results = conn.execute(results_stmt)
        results_df = pandas.DataFrame(
            [
                (
                    row.trial_id,
                    ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id,
                    row.metric_value,
                )
                for row in results.fetchall()
            ],
            columns=["trial_id", "metric", "value"],
        ).pivot(
            index="trial_id",
            columns="metric",
            values="value",
        )
        # Same numeric conversion as for the config values above.
        results_df = results_df.apply(  # type: ignore[assignment]  # (fp)
            pandas.to_numeric,
            errors="coerce",
        ).fillna(results_df)

        # Concat the trials, configs, and results.
        return trials_df.merge(
            configs_df,
            on=["trial_id", "tunable_config_id"],
            how="left",
        ).merge(
            results_df,
            on="trial_id",
            how="left",
        )