Coverage for mlos_bench/mlos_bench/storage/sql/experiment_data.py: 90%
52 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""An interface to access the benchmark experiment data stored in SQL DB using the
6:py:class:`.ExperimentData` interface.
7"""
8import logging
9from typing import Dict, Literal, Optional
11import pandas
12from sqlalchemy import Integer, String, func
13from sqlalchemy.engine import Engine
15from mlos_bench.storage.base_experiment_data import ExperimentData
16from mlos_bench.storage.base_trial_data import TrialData
17from mlos_bench.storage.base_tunable_config_data import TunableConfigData
18from mlos_bench.storage.base_tunable_config_trial_group_data import (
19 TunableConfigTrialGroupData,
20)
21from mlos_bench.storage.sql import common
22from mlos_bench.storage.sql.schema import DbSchema
23from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
24from mlos_bench.storage.sql.tunable_config_trial_group_data import (
25 TunableConfigTrialGroupSqlData,
26)
# Module-level logger, named after this module via ``__name__``.
_LOG = logging.getLogger(__name__)
class ExperimentSqlData(ExperimentData):
    """
    SQL interface for accessing the stored experiment benchmark data.

    An experiment groups together a set of trials that are run with a given set of
    scripts and mlos_bench configuration files.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        description: str,
        root_env_config: str,
        git_repo: str,
        git_commit: str,
    ):
        """
        Create a SQL-backed view of a single experiment's stored data.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine; each property below opens its own connection from it
            (directly or via the ``common`` helpers).
        schema : DbSchema
            Table definitions the queries are built from.
        experiment_id : str
            Id of the experiment whose data is being accessed.
        description, root_env_config, git_repo, git_commit : str
            Experiment metadata, forwarded unchanged to the base class.
        """
        super().__init__(
            experiment_id=experiment_id,
            description=description,
            root_env_config=root_env_config,
            git_repo=git_repo,
            git_commit=git_commit,
        )
        self._engine = engine
        self._schema = schema

    @property
    def objectives(self) -> Dict[str, Literal["min", "max"]]:
        """
        Optimization objectives recorded for this experiment.

        Returns
        -------
        Dict[str, Literal["min", "max"]]
            Maps each ``optimization_target`` to its ``optimization_direction``,
            in order of ``weight`` descending, then target name ascending.
        """
        with self._engine.connect() as conn:
            objectives_db_data = conn.execute(
                self._schema.objectives.select()
                .where(
                    self._schema.objectives.c.exp_id == self._experiment_id,
                )
                .order_by(
                    self._schema.objectives.c.weight.desc(),
                    self._schema.objectives.c.optimization_target.asc(),
                )
            )
            return {
                objective.optimization_target: objective.optimization_direction
                for objective in objectives_db_data.fetchall()
            }

    # TODO: provide a way to get individual data to avoid repeated bulk fetches
    # where only small amounts of data is accessed.
    # Or else make the TrialData object lazily populate.

    @property
    def trials(self) -> Dict[int, TrialData]:
        """All trials of this experiment, as returned by ``common.get_trials``."""
        return common.get_trials(self._engine, self._schema, self._experiment_id)

    @property
    def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]:
        """
        One trial group per distinct tunable config used by this experiment.

        Returns
        -------
        Dict[int, TunableConfigTrialGroupData]
            Keyed by ``config_id``; each group's id is the smallest ``trial_id``
            among the experiment's trials that used that config.
        """
        with self._engine.connect() as conn:
            tunable_config_trial_groups = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.config_id,
                    func.min(self._schema.trial.c.trial_id)
                    .cast(Integer)
                    .label("tunable_config_trial_group_id"),  # pylint: disable=not-callable
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                )
                .group_by(
                    self._schema.trial.c.exp_id,
                    self._schema.trial.c.config_id,
                )
            )
            return {
                tunable_config_trial_group.config_id: TunableConfigTrialGroupSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    experiment_id=self._experiment_id,
                    tunable_config_id=tunable_config_trial_group.config_id,
                    tunable_config_trial_group_id=tunable_config_trial_group.tunable_config_trial_group_id,  # pylint:disable=line-too-long # noqa
                )
                for tunable_config_trial_group in tunable_config_trial_groups.fetchall()
            }

    @property
    def tunable_configs(self) -> Dict[int, TunableConfigData]:
        """
        Distinct tunable configs referenced by this experiment's trials.

        Returns
        -------
        Dict[int, TunableConfigData]
            Keyed by ``config_id`` (cast to integer in the query).
        """
        with self._engine.connect() as conn:
            tunable_configs = conn.execute(
                self._schema.trial.select()
                .with_only_columns(
                    self._schema.trial.c.config_id.cast(Integer).label("config_id"),
                )
                .where(
                    self._schema.trial.c.exp_id == self._experiment_id,
                )
                .group_by(
                    self._schema.trial.c.exp_id,
                    self._schema.trial.c.config_id,
                )
            )
            return {
                tunable_config.config_id: TunableConfigSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    tunable_config_id=tunable_config.config_id,
                )
                for tunable_config in tunable_configs.fetchall()
            }

    @property
    def default_tunable_config_id(self) -> Optional[int]:
        """
        Retrieves the (tunable) config id for the default tunable values for this
        experiment.

        The config of the earliest trial whose ``is_defaults`` parameter is ``"1"``
        or ``"true"`` (compared case-insensitively) is preferred. If no such trial
        exists, the config of the experiment's earliest trial is used instead.

        Note: it is currently possible that the user changed the tunables config
        in between resumptions of an experiment.

        Returns
        -------
        Optional[int]
            The config id, or ``None`` if the experiment has no trials at all.
        """
        trial = self._schema.trial
        trial_param = self._schema.trial_param

        # Preferred candidate: the earliest trial explicitly flagged as defaults.
        first_trial_id_with_defaults = (
            trial_param.select()
            .with_only_columns(
                func.min(trial_param.c.trial_id)
                .cast(Integer)
                .label("first_trial_id_with_defaults"),  # pylint: disable=not-callable
            )
            .where(
                trial_param.c.exp_id == self._experiment_id,
                trial_param.c.param_id == "is_defaults",
                func.lower(trial_param.c.param_value, type_=String).in_(["1", "true"]),
            )
            .scalar_subquery()
        )
        # Fallback candidate: assume the minimum trial_id for the experiment.
        first_trial_id = (
            trial.select()
            .with_only_columns(
                func.min(trial.c.trial_id)
                .cast(Integer)
                .label("first_trial_id"),
            )
            .where(
                trial.c.exp_id == self._experiment_id,
            )
            .scalar_subquery()
        )

        with self._engine.connect() as conn:
            # Try candidates in priority order. When a candidate subquery yields
            # NULL (no matching trials), the IN filter matches nothing and we move on.
            for trial_id_subquery in (first_trial_id_with_defaults, first_trial_id):
                row = conn.execute(
                    trial.select()
                    .with_only_columns(
                        trial.c.config_id.cast(Integer).label("config_id"),
                    )
                    .where(
                        trial.c.exp_id == self._experiment_id,
                        trial.c.trial_id.in_(trial_id_subquery),
                    )
                ).fetchone()
                if row is not None:
                    # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy
                    return row._tuple()[0]
        return None

    @property
    def results_df(self) -> pandas.DataFrame:
        """Trial results for this experiment, as built by ``common.get_results_df``."""
        return common.get_results_df(self._engine, self._schema, self._experiment_id)