Coverage for mlos_bench/mlos_bench/storage/sql/experiment_data.py: 96%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""An interface to access the experiment benchmark data stored in SQL DB.""" 

6import logging 

7from typing import Dict, Literal, Optional 

8 

9import pandas 

10from sqlalchemy import Engine, Integer, String, func 

11 

12from mlos_bench.storage.base_experiment_data import ExperimentData 

13from mlos_bench.storage.base_trial_data import TrialData 

14from mlos_bench.storage.base_tunable_config_data import TunableConfigData 

15from mlos_bench.storage.base_tunable_config_trial_group_data import ( 

16 TunableConfigTrialGroupData, 

17) 

18from mlos_bench.storage.sql import common 

19from mlos_bench.storage.sql.schema import DbSchema 

20from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData 

21from mlos_bench.storage.sql.tunable_config_trial_group_data import ( 

22 TunableConfigTrialGroupSqlData, 

23) 

24 

25_LOG = logging.getLogger(__name__) 

26 

27 

class ExperimentSqlData(ExperimentData):
    """
    SQL-backed accessor for the stored benchmark data of a single experiment.

    An experiment is the set of trials that were run with one particular collection
    of scripts and mlos_bench configuration files.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        description: str,
        root_env_config: str,
        git_repo: str,
        git_commit: str,
    ):
        """
        Create the accessor for one experiment.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open a connection for every query.
        schema : DbSchema
            Schema object whose tables are queried.
        experiment_id : str
            Forwarded to the ExperimentData constructor.
        description : str
            Forwarded to the ExperimentData constructor.
        root_env_config : str
            Forwarded to the ExperimentData constructor.
        git_repo : str
            Forwarded to the ExperimentData constructor.
        git_commit : str
            Forwarded to the ExperimentData constructor.
        """
        super().__init__(
            experiment_id=experiment_id,
            description=description,
            root_env_config=root_env_config,
            git_repo=git_repo,
            git_commit=git_commit,
        )
        self._engine = engine
        self._schema = schema

    @property
    def objectives(self) -> Dict[str, Literal["min", "max"]]:
        """
        Optimization objectives recorded for this experiment.

        Returns
        -------
        Dict[str, Literal["min", "max"]]
            Maps each optimization target to its optimization direction. Rows are
            read by descending weight, then by ascending target name.
        """
        objectives_table = self._schema.objectives
        query = (
            objectives_table.select()
            .where(
                objectives_table.c.exp_id == self._experiment_id,
            )
            .order_by(
                objectives_table.c.weight.desc(),
                objectives_table.c.optimization_target.asc(),
            )
        )
        directions: Dict[str, Literal["min", "max"]] = {}
        with self._engine.connect() as conn:
            for row in conn.execute(query).fetchall():
                directions[row.optimization_target] = row.optimization_direction
            return directions

    # TODO: provide a way to get individual data to avoid repeated bulk fetches
    # where only small amounts of data are accessed.
    # Or else make the TrialData object lazily populate.

    @property
    def trials(self) -> Dict[int, TrialData]:
        """
        Trials of this experiment, as returned by ``common.get_trials`` for this
        object's engine, schema and experiment id.
        """
        all_trials = common.get_trials(self._engine, self._schema, self._experiment_id)
        return all_trials

    @property
    def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]:
        """
        Trial groups of this experiment, one per distinct tunable config.

        Returns
        -------
        Dict[int, TunableConfigTrialGroupData]
            Keyed by config id. The id of each group is the smallest trial id that
            used that config within this experiment.
        """
        trial = self._schema.trial
        group_id_column = (
            func.min(trial.c.trial_id)  # pylint: disable=not-callable
            .cast(Integer)
            .label("tunable_config_trial_group_id")
        )
        query = (
            trial.select()
            .with_only_columns(
                trial.c.config_id,
                group_id_column,
            )
            .where(
                trial.c.exp_id == self._experiment_id,
            )
            .group_by(
                trial.c.exp_id,
                trial.c.config_id,
            )
        )
        groups: Dict[int, TunableConfigTrialGroupData] = {}
        with self._engine.connect() as conn:
            for row in conn.execute(query).fetchall():
                groups[row.config_id] = TunableConfigTrialGroupSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    experiment_id=self._experiment_id,
                    tunable_config_id=row.config_id,
                    tunable_config_trial_group_id=row.tunable_config_trial_group_id,
                )
            return groups

    @property
    def tunable_configs(self) -> Dict[int, TunableConfigData]:
        """
        Tunable configs used by at least one trial of this experiment.

        Returns
        -------
        Dict[int, TunableConfigData]
            Keyed by config id.
        """
        trial = self._schema.trial
        query = (
            trial.select()
            .with_only_columns(
                trial.c.config_id.cast(Integer).label("config_id"),
            )
            .where(
                trial.c.exp_id == self._experiment_id,
            )
            .group_by(
                trial.c.exp_id,
                trial.c.config_id,
            )
        )
        configs: Dict[int, TunableConfigData] = {}
        with self._engine.connect() as conn:
            for row in conn.execute(query).fetchall():
                configs[row.config_id] = TunableConfigSqlData(
                    engine=self._engine,
                    schema=self._schema,
                    tunable_config_id=row.config_id,
                )
            return configs

    @property
    def default_tunable_config_id(self) -> Optional[int]:
        """
        Look up the (tunable) config id that holds the default tunable values for
        this experiment.

        The config of the lowest-numbered trial carrying an ``is_defaults`` trial
        parameter equal to "1" or "true" (compared in lower case) is preferred. When
        that lookup yields no row, the config of the lowest-numbered trial of the
        experiment is used instead.

        Note: by *default* this is the first trial executed for the experiment.
        However, it is currently possible that the user changed the tunables config
        in between resumptions of an experiment.

        Returns
        -------
        Optional[int]
            The config id, or None when neither lookup yields a row.
        """
        trial = self._schema.trial
        trial_param = self._schema.trial_param
        exp_id = self._experiment_id

        # Preferred source: the earliest trial explicitly flagged as using defaults.
        first_defaults_trial_id = (
            trial_param.select()
            .with_only_columns(
                func.min(trial_param.c.trial_id)  # pylint: disable=not-callable
                .cast(Integer)
                .label("first_trial_id_with_defaults"),
            )
            .where(
                trial_param.c.exp_id == exp_id,
                trial_param.c.param_id == "is_defaults",
                func.lower(trial_param.c.param_value, type_=String).in_(["1", "true"]),
            )
            .scalar_subquery()
        )
        # Fallback source: assume the minimum trial_id of the experiment.
        first_trial_id = (
            trial.select()
            .with_only_columns(
                func.min(trial.c.trial_id)  # pylint: disable=not-callable
                .cast(Integer)
                .label("first_trial_id"),
            )
            .where(
                trial.c.exp_id == exp_id,
            )
            .scalar_subquery()
        )

        with self._engine.connect() as conn:
            # Both lookups share the same outer query; only the trial-id source differs.
            for trial_id_subquery in (first_defaults_trial_id, first_trial_id):
                config_id_row = conn.execute(
                    trial.select()
                    .with_only_columns(
                        trial.c.config_id.cast(Integer).label("config_id"),
                    )
                    .where(
                        trial.c.exp_id == exp_id,
                        trial.c.trial_id.in_(trial_id_subquery),
                    )
                ).fetchone()
                if config_id_row is not None:
                    # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy
                    return config_id_row._tuple()[0]
            return None

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Results of this experiment as a DataFrame, as returned by
        ``common.get_results_df`` for this object's engine, schema and experiment id.
        """
        results = common.get_results_df(self._engine, self._schema, self._experiment_id)
        return results