Coverage for mlos_bench/mlos_bench/storage/sql/experiment_data.py: 96%
51 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""An interface to access the experiment benchmark data stored in SQL DB."""
6import logging
7from typing import Dict, Literal, Optional
9import pandas
10from sqlalchemy import Engine, Integer, String, func
12from mlos_bench.storage.base_experiment_data import ExperimentData
13from mlos_bench.storage.base_trial_data import TrialData
14from mlos_bench.storage.base_tunable_config_data import TunableConfigData
15from mlos_bench.storage.base_tunable_config_trial_group_data import (
16 TunableConfigTrialGroupData,
17)
18from mlos_bench.storage.sql import common
19from mlos_bench.storage.sql.schema import DbSchema
20from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
21from mlos_bench.storage.sql.tunable_config_trial_group_data import (
22 TunableConfigTrialGroupSqlData,
23)
25_LOG = logging.getLogger(__name__)
class ExperimentSqlData(ExperimentData):
    """
    SQL-backed accessor for the benchmark data stored for a single experiment.

    An experiment groups together a set of trials that are run with a given set of
    scripts and mlos_bench configuration files.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        description: str,
        root_env_config: str,
        git_repo: str,
        git_commit: str,
    ):
        """
        Create the accessor for one experiment.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open a connection for each query.
        schema : DbSchema
            Table definitions the queries in this class are built from.
        experiment_id : str
            Forwarded unchanged to the base class constructor.
        description : str
            Forwarded unchanged to the base class constructor.
        root_env_config : str
            Forwarded unchanged to the base class constructor.
        git_repo : str
            Forwarded unchanged to the base class constructor.
        git_commit : str
            Forwarded unchanged to the base class constructor.
        """
        super().__init__(
            experiment_id=experiment_id,
            description=description,
            root_env_config=root_env_config,
            git_repo=git_repo,
            git_commit=git_commit,
        )
        self._engine = engine
        self._schema = schema

    @property
    def objectives(self) -> Dict[str, Literal["min", "max"]]:
        """
        Read the optimization objectives recorded for this experiment.

        Rows are ordered by descending weight and then by ascending target name, and
        the returned dict is filled in that order.

        Returns
        -------
        Dict[str, Literal["min", "max"]]
            Maps each ``optimization_target`` to its ``optimization_direction``.
        """
        objectives_tbl = self._schema.objectives
        stmt = (
            objectives_tbl.select()
            .where(objectives_tbl.c.exp_id == self._experiment_id)
            .order_by(
                objectives_tbl.c.weight.desc(),
                objectives_tbl.c.optimization_target.asc(),
            )
        )
        with self._engine.connect() as conn:
            directions = {}
            for row in conn.execute(stmt).fetchall():
                directions[row.optimization_target] = row.optimization_direction
            return directions

    # TODO: provide a way to fetch individual items so that callers which only need
    # a small amount of data do not trigger repeated bulk fetches.
    # Or else make the TrialData object lazily populate.

    @property
    def trials(self) -> Dict[int, TrialData]:
        """
        Fetch the trials of this experiment.

        Delegates to ``common.get_trials``; a new fetch is issued on every access.

        Returns
        -------
        Dict[int, TrialData]
            Whatever ``common.get_trials`` returns for this experiment
            (presumably keyed by trial id - confirm against that helper).
        """
        all_trials = common.get_trials(self._engine, self._schema, self._experiment_id)
        return all_trials

    @property
    def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]:
        """
        Build one trial-group accessor per tunable config used in this experiment.

        The group id of a config is the smallest ``trial_id`` among the trials of this
        experiment that used that config.

        Returns
        -------
        Dict[int, TunableConfigTrialGroupData]
            Trial-group accessors keyed by ``config_id``.
        """
        trial_tbl = self._schema.trial
        stmt = (
            trial_tbl.select()
            .with_only_columns(
                trial_tbl.c.config_id,
                func.min(trial_tbl.c.trial_id)
                .cast(Integer)
                .label("tunable_config_trial_group_id"),  # pylint: disable=not-callable
            )
            .where(trial_tbl.c.exp_id == self._experiment_id)
            .group_by(trial_tbl.c.exp_id, trial_tbl.c.config_id)
        )
        with self._engine.connect() as conn:
            groups: Dict[int, TunableConfigTrialGroupData] = {}
            for row in conn.execute(stmt).fetchall():
                groups[row.config_id] = TunableConfigTrialGroupSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    experiment_id=self._experiment_id,
                    tunable_config_id=row.config_id,
                    tunable_config_trial_group_id=row.tunable_config_trial_group_id,
                )
            return groups

    @property
    def tunable_configs(self) -> Dict[int, TunableConfigData]:
        """
        Build one config accessor per distinct tunable config used in this experiment.

        Returns
        -------
        Dict[int, TunableConfigData]
            Config accessors keyed by ``config_id``.
        """
        trial_tbl = self._schema.trial
        stmt = (
            trial_tbl.select()
            .with_only_columns(
                trial_tbl.c.config_id.cast(Integer).label("config_id"),
            )
            .where(trial_tbl.c.exp_id == self._experiment_id)
            # Grouping collapses the many trials of one config into a single row.
            .group_by(trial_tbl.c.exp_id, trial_tbl.c.config_id)
        )
        with self._engine.connect() as conn:
            configs: Dict[int, TunableConfigData] = {}
            for row in conn.execute(stmt).fetchall():
                configs[row.config_id] = TunableConfigSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    tunable_config_id=row.config_id,
                )
            return configs

    @property
    def default_tunable_config_id(self) -> Optional[int]:
        """
        Retrieves the (tunable) config id for the default tunable values for this
        experiment.

        Note: this is by *default* the first trial executed for this experiment.
        However, it is currently possible that the user changed the tunables config
        in between resumptions of an experiment.

        The lookup first tries the lowest trial id whose ``is_defaults`` trial
        parameter is ``"1"`` or ``"true"`` (compared case-insensitively). If no such
        trial exists it falls back to the lowest trial id of the experiment.

        Returns
        -------
        Optional[int]
            The config id of the selected trial, or None when neither lookup
            returns a row.
        """
        trial_tbl = self._schema.trial
        param_tbl = self._schema.trial_param

        # Preferred source: the first trial explicitly flagged as using defaults.
        first_default_trial_id = (
            param_tbl.select()
            .with_only_columns(
                func.min(param_tbl.c.trial_id)
                .cast(Integer)
                .label("first_trial_id_with_defaults"),  # pylint: disable=not-callable
            )
            .where(
                param_tbl.c.exp_id == self._experiment_id,
                param_tbl.c.param_id == "is_defaults",
                func.lower(param_tbl.c.param_value, type_=String).in_(["1", "true"]),
            )
            .scalar_subquery()
        )
        # Fallback source: simply the first trial of the experiment.
        first_trial_id = (
            trial_tbl.select()
            .with_only_columns(
                func.min(trial_tbl.c.trial_id).cast(Integer).label("first_trial_id"),
            )
            .where(trial_tbl.c.exp_id == self._experiment_id)
            .scalar_subquery()
        )

        def select_config_id(trial_id_subquery):
            """Select the config_id of this experiment's trial picked by the subquery."""
            return (
                trial_tbl.select()
                .with_only_columns(
                    trial_tbl.c.config_id.cast(Integer).label("config_id"),
                )
                .where(
                    trial_tbl.c.exp_id == self._experiment_id,
                    trial_tbl.c.trial_id.in_(trial_id_subquery),
                )
            )

        with self._engine.connect() as conn:
            # Candidates are tried in priority order; the fallback query is only
            # executed when the preferred one yields no row.
            for trial_id_subquery in (first_default_trial_id, first_trial_id):
                row = conn.execute(select_config_id(trial_id_subquery)).fetchone()
                if row is not None:
                    # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy
                    return row._tuple()[0]
            return None

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Fetch the results of this experiment as a data frame.

        Delegates to ``common.get_results_df``; a new fetch is issued on every access.

        Returns
        -------
        pandas.DataFrame
            Whatever ``common.get_results_df`` returns for this experiment.
        """
        frame = common.get_results_df(self._engine, self._schema, self._experiment_id)
        return frame