Coverage for mlos_bench/mlos_bench/storage/sql/common.py: 100%
47 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Common SQL methods for accessing the stored benchmark data."""
7from collections.abc import Mapping
8from typing import Any
10import pandas
11from sqlalchemy import Integer, and_, func, select
12from sqlalchemy.engine import Connection, Engine
13from sqlalchemy.schema import Table
15from mlos_bench.environments.status import Status
16from mlos_bench.storage.base_experiment_data import ExperimentData
17from mlos_bench.storage.base_trial_data import TrialData
18from mlos_bench.storage.sql.schema import DbSchema
19from mlos_bench.util import nullable, utcify_nullable_timestamp, utcify_timestamp
def save_params(
    conn: Connection,
    table: Table,
    params: Mapping[str, Any],
    **kwargs: Any,
) -> None:
    """
    Insert one row per (param_id, param_value) pair into the given Table.

    Parameters
    ----------
    conn : sqlalchemy.engine.Connection
        A connection to the backend database.
    table : sqlalchemy.schema.Table
        The table that receives the new rows.
    params : Mapping[str, Any]
        The (param_id, param_value) pairs to store.
        Nothing is sent to the database when the mapping is empty.
    **kwargs : dict[str, Any]
        Extra column values (e.g., primary key info for the given table)
        that are copied into every inserted row.
    """
    if params:
        rows = []
        for param_id, param_value in params.items():
            # Start from the shared key columns, then add the per-parameter ones
            # (which take precedence over same-named entries in kwargs).
            row = dict(kwargs)
            row["param_id"] = param_id
            # Values are stored via `nullable(str, ...)`;
            # presumably None stays NULL and anything else is stringified.
            row["param_value"] = nullable(str, param_value)
            rows.append(row)
        # A list of parameter dicts makes this a single bulk INSERT call.
        conn.execute(table.insert(), rows)
def get_trials(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: int | None = None,
) -> dict[int, TrialData]:
    """
    Fetch the trials of ``experiment_id`` as :py:class:`~.TrialData` objects keyed by
    their ``trial_id``, optionally keeping only the trials that used
    ``tunable_config_id``.

    Parameters
    ----------
    engine : sqlalchemy.engine.Engine
        The engine used to connect to the backend database.
    schema : DbSchema
        The schema object providing the ``trial`` table.
    experiment_id : str
        ID of the experiment whose trials are fetched.
    tunable_config_id : int | None
        When not None, only trials with this ``config_id`` are returned.

    Returns
    -------
    trials : dict[int, TrialData]
        The matching trials, inserted in ascending ``trial_id`` order.

    See Also
    --------
    :py:class:`~mlos_bench.storage.sql.tunable_config_trial_group_data.TunableConfigTrialGroupSqlData`
    :py:class:`~mlos_bench.storage.sql.experiment_data.ExperimentSqlData`
    """  # pylint: disable=line-too-long # noqa: E501
    # pylint: disable=import-outside-toplevel,cyclic-import
    from mlos_bench.storage.sql.trial_data import TrialSqlData

    with engine.connect() as conn:
        trial_table = schema.trial
        # Select all trials of the experiment, ordered by trial_id.
        query = (
            trial_table.select()
            .where(trial_table.c.exp_id == experiment_id)
            .order_by(
                trial_table.c.exp_id.asc(),
                trial_table.c.trial_id.asc(),
            )
        )
        # Narrow down to a single tunable config if one was requested.
        if tunable_config_id is not None:
            query = query.where(trial_table.c.config_id == tunable_config_id)

        trials_by_id: dict[int, TrialData] = {}
        for row in conn.execute(query).fetchall():
            trials_by_id[row.trial_id] = TrialSqlData(
                engine=engine,
                schema=schema,
                experiment_id=experiment_id,
                trial_id=row.trial_id,
                config_id=row.config_id,
                # Stored timestamps are handed to the utcify helpers with a "utc" origin.
                ts_start=utcify_timestamp(row.ts_start, origin="utc"),
                ts_end=utcify_nullable_timestamp(row.ts_end, origin="utc"),
                # The stored status string is looked up by name in Status.
                status=Status[row.status],
                trial_runner_id=row.trial_runner_id,
            )
        return trials_by_id
def get_results_df(
    engine: Engine,
    schema: DbSchema,
    experiment_id: str,
    tunable_config_id: int | None = None,
) -> pandas.DataFrame:
    """
    Build a wide-format DataFrame describing the trials of ``experiment_id``,
    optionally restricted to the trials that used ``tunable_config_id``.

    Each row holds one trial's metadata, its config parameters (columns prefixed
    with :py:attr:`.ExperimentData.CONFIG_COLUMN_PREFIX`), and its results (columns
    prefixed with :py:attr:`.ExperimentData.RESULT_COLUMN_PREFIX`).

    Parameters
    ----------
    engine : sqlalchemy.engine.Engine
        The engine used to connect to the backend database.
    schema : DbSchema
        The schema object providing the ``trial``, ``config_param``, and
        ``trial_result`` tables.
    experiment_id : str
        ID of the experiment whose trials are fetched.
    tunable_config_id : int | None
        When not None, only trials with this ``config_id`` are included.

    Returns
    -------
    results : pandas.DataFrame
        The trial metadata left-joined with the pivoted configs and results.

    See Also
    --------
    :py:class:`~mlos_bench.storage.sql.tunable_config_trial_group_data.TunableConfigTrialGroupSqlData`
    :py:class:`~mlos_bench.storage.sql.experiment_data.ExperimentSqlData`
    """  # pylint: disable=line-too-long # noqa: E501
    # pylint: disable=too-many-locals
    with engine.connect() as conn:
        trial_tbl = schema.trial
        param_tbl = schema.config_param
        result_tbl = schema.trial_result

        # --- Trial group ids -------------------------------------------------
        # The group id of a tunable config is the smallest trial_id that used it.
        group_id_col = (
            func.min(trial_tbl.c.trial_id).cast(Integer).label("tunable_config_trial_group_id")
        )
        group_stmt = (
            trial_tbl.select()
            .with_only_columns(
                trial_tbl.c.exp_id,
                trial_tbl.c.config_id,
                group_id_col,
            )
            .where(trial_tbl.c.exp_id == experiment_id)
            .group_by(
                trial_tbl.c.exp_id,
                trial_tbl.c.config_id,
            )
        )
        if tunable_config_id is not None:
            group_stmt = group_stmt.where(trial_tbl.c.config_id == tunable_config_id)
        group_subq = group_stmt.subquery()

        # --- Trial metadata --------------------------------------------------
        # Pair every trial with the group id computed for its config.
        same_group = and_(
            group_subq.c.exp_id == trial_tbl.c.exp_id,
            group_subq.c.config_id == trial_tbl.c.config_id,
        )
        trials_stmt = (
            select(trial_tbl, group_subq)
            .where(trial_tbl.c.exp_id == experiment_id, same_group)
            .order_by(
                trial_tbl.c.exp_id.asc(),
                trial_tbl.c.trial_id.asc(),
            )
        )
        if tunable_config_id is not None:
            trials_stmt = trials_stmt.where(trial_tbl.c.config_id == tunable_config_id)

        trial_columns = [
            "trial_id",
            "ts_start",
            "ts_end",
            "tunable_config_id",
            "tunable_config_trial_group_id",
            "status",
            "trial_runner_id",
        ]
        trial_records = []
        for row in conn.execute(trials_stmt).fetchall():
            trial_records.append(
                (
                    row.trial_id,
                    utcify_timestamp(row.ts_start, origin="utc"),
                    utcify_nullable_timestamp(row.ts_end, origin="utc"),
                    row.config_id,
                    row.tunable_config_trial_group_id,
                    row.status,
                    row.trial_runner_id,
                )
            )
        trials_df = pandas.DataFrame(trial_records, columns=trial_columns)

        # --- Config parameters (long -> wide) --------------------------------
        config_stmt = (
            trial_tbl.select()
            .with_only_columns(
                trial_tbl.c.trial_id,
                trial_tbl.c.config_id,
                param_tbl.c.param_id,
                param_tbl.c.param_value,
            )
            .where(trial_tbl.c.exp_id == experiment_id)
            .join(
                param_tbl,
                param_tbl.c.config_id == trial_tbl.c.config_id,
            )
            .order_by(
                trial_tbl.c.trial_id,
                param_tbl.c.param_id,
            )
        )
        if tunable_config_id is not None:
            config_stmt = config_stmt.where(trial_tbl.c.config_id == tunable_config_id)

        config_records = [
            (
                row.trial_id,
                row.config_id,
                ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id,
                row.param_value,
            )
            for row in conn.execute(config_stmt).fetchall()
        ]
        configs_long = pandas.DataFrame(
            config_records,
            columns=["trial_id", "tunable_config_id", "param", "value"],
        )
        # One row per (trial, config), one column per prefixed parameter name.
        configs_wide = configs_long.pivot(
            index=["trial_id", "tunable_config_id"],
            columns="param",
            values="value",
        )
        # Parse values as numbers where possible; anything that fails to parse
        # becomes NaN under "coerce" and is then restored from the original frame.
        configs_numeric = configs_wide.apply(pandas.to_numeric, errors="coerce")
        configs_df = configs_numeric.fillna(configs_wide)

        # --- Trial results (long -> wide) ------------------------------------
        results_stmt = (
            result_tbl.select()
            .with_only_columns(
                result_tbl.c.trial_id,
                result_tbl.c.metric_id,
                result_tbl.c.metric_value,
            )
            .where(result_tbl.c.exp_id == experiment_id)
            .order_by(
                result_tbl.c.trial_id,
                result_tbl.c.metric_id,
            )
        )
        if tunable_config_id is not None:
            # Results carry no config_id, so filter through a join on the trial table.
            uses_requested_config = and_(
                trial_tbl.c.exp_id == result_tbl.c.exp_id,
                trial_tbl.c.trial_id == result_tbl.c.trial_id,
                trial_tbl.c.config_id == tunable_config_id,
            )
            results_stmt = results_stmt.join(trial_tbl, uses_requested_config)

        result_records = [
            (
                row.trial_id,
                ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id,
                row.metric_value,
            )
            for row in conn.execute(results_stmt).fetchall()
        ]
        results_long = pandas.DataFrame(
            result_records,
            columns=["trial_id", "metric", "value"],
        )
        # One row per trial, one column per prefixed metric name.
        results_wide = results_long.pivot(
            index="trial_id",
            columns="metric",
            values="value",
        )
        # Same numeric conversion with fallback as for the config values.
        results_numeric = results_wide.apply(pandas.to_numeric, errors="coerce")
        results_df = results_numeric.fillna(results_wide)

        # --- Combine -----------------------------------------------------------
        # Left joins keep every trial, including those without configs or results.
        with_configs = trials_df.merge(
            configs_df,
            on=["trial_id", "tunable_config_id"],
            how="left",
        )
        return with_configs.merge(results_df, on="trial_id", how="left")