Coverage for mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py: 100%
29 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""An interface to access the tunable config trial group data stored in a SQL DB using
6the :py:class:`.TunableConfigTrialGroupData` interface.
7"""
9from typing import TYPE_CHECKING, Dict, Optional
11import pandas
12from sqlalchemy import Integer, func
13from sqlalchemy.engine import Engine
15from mlos_bench.storage.base_tunable_config_data import TunableConfigData
16from mlos_bench.storage.base_tunable_config_trial_group_data import (
17 TunableConfigTrialGroupData,
18)
19from mlos_bench.storage.sql import common
20from mlos_bench.storage.sql.schema import DbSchema
21from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData
23if TYPE_CHECKING:
24 from mlos_bench.storage.base_trial_data import TrialData
class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData):
    """
    SQL interface for accessing the stored experiment benchmark tunable config trial
    group data.

    A (tunable) config is used to define an instance of values for a set of tunable
    parameters for a given experiment and can be used by one or more trial instances
    (e.g., for repeats), which we call a (tunable) config trial group.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        tunable_config_id: int,
        tunable_config_trial_group_id: Optional[int] = None,
    ):
        """
        Create a SQL-backed view of a (tunable) config trial group.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections to the storage DB.
        schema : DbSchema
            Schema object whose tables are queried.
        experiment_id : str
            ID of the experiment the trial group belongs to.
        tunable_config_id : int
            ID of the (tunable) config shared by the trials in the group.
        tunable_config_trial_group_id : Optional[int]
            ID of the trial group, if already known. It is forwarded to the base
            class; presumably the base class calls
            ``_get_tunable_config_trial_group_id`` lazily when it is None
            (TODO confirm against the base class).
        """
        super().__init__(
            experiment_id=experiment_id,
            tunable_config_id=tunable_config_id,
            tunable_config_trial_group_id=tunable_config_trial_group_id,
        )
        self._engine = engine
        self._schema = schema

    def _get_tunable_config_trial_group_id(self) -> int:
        """
        Retrieve the tunable_config_trial_group_id from the storage.

        The group id is the smallest ``trial_id`` among this experiment's trials
        that use this tunable config.

        Returns
        -------
        int
            The tunable config trial group id.

        Raises
        ------
        ValueError
            If the storage has no trials for this experiment and tunable config.
        """
        trial_table = self._schema.trial
        with self._engine.connect() as conn:
            group_id = conn.execute(
                trial_table.select()
                .with_only_columns(
                    func.min(trial_table.c.trial_id)
                    .cast(Integer)
                    .label("tunable_config_trial_group_id"),  # pylint: disable=not-callable
                )
                .where(
                    trial_table.c.exp_id == self._experiment_id,
                    trial_table.c.config_id == self._tunable_config_id,
                )
                # Grouping (rather than a bare aggregate) makes the result empty,
                # instead of a single NULL row, when no trials match.
                .group_by(
                    trial_table.c.exp_id,
                    trial_table.c.config_id,
                )
            ).scalar()
        # An explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would turn this into an obscure error later on.
        if group_id is None:
            raise ValueError(
                f"No trials found for experiment {self._experiment_id!r} "
                f"with tunable config id {self._tunable_config_id}"
            )
        return group_id

    @property
    def tunable_config(self) -> TunableConfigData:
        """
        Retrieve the (tunable) config data shared by this trial group.

        Returns
        -------
        TunableConfigData
            A SQL-backed view of the config identified by ``tunable_config_id``.
        """
        return TunableConfigSqlData(
            engine=self._engine,
            schema=self._schema,
            tunable_config_id=self.tunable_config_id,
        )

    @property
    def trials(self) -> Dict[int, "TrialData"]:
        """
        Retrieve the trials' data for this (tunable) config trial group from the
        storage.

        Returns
        -------
        trials : Dict[int, TrialData]
            A dictionary of the trials' data, keyed by trial id.
        """
        return common.get_trials(
            self._engine,
            self._schema,
            self._experiment_id,
            self._tunable_config_id,
        )

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Retrieve the results of this (tunable) config trial group's trials.

        Delegates to ``common.get_results_df`` with this group's experiment id and
        tunable config id.

        Returns
        -------
        results : pandas.DataFrame
            The trial results as a DataFrame.
        """
        return common.get_results_df(
            self._engine,
            self._schema,
            self._experiment_id,
            self._tunable_config_id,
        )