Coverage for mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py: 100%
28 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""An interface to access the tunable config trial group data stored in SQL DB."""
7from typing import TYPE_CHECKING, Dict, Optional
9import pandas
10from sqlalchemy import Engine, Integer, func
12from mlos_bench.storage.base_tunable_config_data import TunableConfigData
13from mlos_bench.storage.base_tunable_config_trial_group_data import (
14 TunableConfigTrialGroupData,
15)
16from mlos_bench.storage.sql import common
17from mlos_bench.storage.sql.schema import DbSchema
18from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
20if TYPE_CHECKING:
21 from mlos_bench.storage.base_trial_data import TrialData
class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData):
    """
    SQL interface for accessing the stored experiment benchmark tunable config trial
    group data.

    A (tunable) config is used to define an instance of values for a set of tunable
    parameters for a given experiment and can be used by one or more trial instances
    (e.g., for repeats), which we call a (tunable) config trial group.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        tunable_config_id: int,
        tunable_config_trial_group_id: Optional[int] = None,
    ):
        """
        Create a handle to a (tunable) config trial group stored in the SQL DB.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections to the storage DB.
        schema : DbSchema
            Table definitions of the storage DB.
        experiment_id : str
            ID of the experiment this trial group belongs to.
        tunable_config_id : int
            ID of the (tunable) config shared by the trials in this group.
        tunable_config_trial_group_id : Optional[int]
            ID of the trial group, if already known; forwarded to the base class.
            NOTE(review): presumably the base class falls back to
            `_get_tunable_config_trial_group_id` when this is None -- confirm.
        """
        super().__init__(
            experiment_id=experiment_id,
            tunable_config_id=tunable_config_id,
            tunable_config_trial_group_id=tunable_config_trial_group_id,
        )
        self._engine = engine
        self._schema = schema

    def _get_tunable_config_trial_group_id(self) -> int:
        """
        Retrieve the trial's tunable_config_trial_group_id from the storage.

        The group id is the smallest ``trial_id`` among the trials of this
        experiment that use this tunable config.

        Returns
        -------
        tunable_config_trial_group_id : int
            The smallest matching trial id.

        Raises
        ------
        ValueError
            If the storage has no trials for this experiment and tunable config.
        """
        trial_table = self._schema.trial
        with self._engine.connect() as conn:
            cur_result = conn.execute(
                trial_table.select()
                .with_only_columns(
                    func.min(trial_table.c.trial_id)  # pylint: disable=not-callable
                    .cast(Integer)
                    .label("tunable_config_trial_group_id"),
                )
                .where(
                    trial_table.c.exp_id == self._experiment_id,
                    trial_table.c.config_id == self._tunable_config_id,
                )
                .group_by(
                    trial_table.c.exp_id,
                    trial_table.c.config_id,
                )
            )
            row = cur_result.fetchone()
            # Because of the GROUP BY, having no matching trials yields no row at all
            # (rather than a single row with a NULL MIN). Use an explicit check: an
            # `assert` would be stripped under `python -O`, and the failure would
            # then surface as an AttributeError on None.
            if row is None:
                raise ValueError(
                    f"No trials found for experiment {self._experiment_id!r} "
                    f"with tunable config id {self._tunable_config_id}"
                )
            # Row is tuple-like; positional indexing avoids the protected `_tuple()`.
            return row[0]

    @property
    def tunable_config(self) -> TunableConfigData:
        """
        Retrieve the (tunable) config data shared by the trials of this group.

        Returns
        -------
        tunable_config : TunableConfigData
            A SQL-backed handle to the config identified by ``tunable_config_id``.
        """
        return TunableConfigSqlData(
            engine=self._engine,
            schema=self._schema,
            tunable_config_id=self.tunable_config_id,
        )

    @property
    def trials(self) -> Dict[int, "TrialData"]:
        """
        Retrieve the trials' data for this (tunable) config trial group from the
        storage.

        Returns
        -------
        trials : Dict[int, TrialData]
            A dictionary of the trials' data, keyed by trial id.
        """
        return common.get_trials(
            self._engine,
            self._schema,
            self._experiment_id,
            self._tunable_config_id,
        )

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Retrieve the results data for this (tunable) config trial group from the
        storage.

        Returns
        -------
        results : pandas.DataFrame
            The frame produced by ``common.get_results_df`` for this group's
            experiment id and tunable config id.
        """
        return common.get_results_df(
            self._engine,
            self._schema,
            self._experiment_id,
            self._tunable_config_id,
        )