Coverage for mlos_bench/mlos_bench/storage/sql/trial_data.py: 100%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""An interface to access the benchmark trial data stored in SQL DB using the 

6:py:class:`.TrialData` interface. 

7""" 

8from datetime import datetime 

9from typing import TYPE_CHECKING, Optional 

10 

11import pandas 

12from sqlalchemy.engine import Engine 

13 

14from mlos_bench.environments.status import Status 

15from mlos_bench.storage.base_trial_data import TrialData 

16from mlos_bench.storage.base_tunable_config_data import TunableConfigData 

17from mlos_bench.storage.sql.schema import DbSchema 

18from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData 

19from mlos_bench.util import utcify_timestamp 

20 

21if TYPE_CHECKING: 

22 from mlos_bench.storage.base_tunable_config_trial_group_data import ( 

23 TunableConfigTrialGroupData, 

24 ) 

25 

26 

class TrialSqlData(TrialData):
    """
    SQL-backed implementation of the :py:class:`.TrialData` interface.

    Every accessor opens its own connection from the stored SQLAlchemy engine.
    It then queries the tables described by the stored schema, filtered to
    this trial's experiment id and trial id.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        ts_start: datetime,
        ts_end: Optional[datetime],
        status: Status,
    ):
        """
        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open a connection for each query.
        schema : DbSchema
            Table definitions that the queries are built from.
        experiment_id : str
            Id of the experiment this trial belongs to.
        trial_id : int
            Id of the trial within the experiment.
        config_id : int
            Id of the trial's tunable configuration. It is forwarded to the
            base class as ``tunable_config_id``.
        ts_start : datetime
            Trial start timestamp (forwarded to the base class).
        ts_end : Optional[datetime]
            Trial end timestamp, possibly None (forwarded to the base class).
        status : Status
            Trial status (forwarded to the base class).
        """
        # The SQL layer names this value `config_id`; the base class expects
        # `tunable_config_id`.
        super().__init__(
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            ts_start=ts_start,
            ts_end=ts_end,
            status=status,
        )
        self._schema = schema
        self._engine = engine

    @property
    def tunable_config(self) -> TunableConfigData:
        """
        Load this trial's tunable configuration from storage.

        Note: this corresponds to the Trial object's "tunables" property.
        """
        config_data = TunableConfigSqlData(
            tunable_config_id=self._tunable_config_id,
            schema=self._schema,
            engine=self._engine,
        )
        return config_data

    @property
    def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData":
        """Load the tunable-config trial-group data for this trial's experiment
        and tunable config id from storage.
        """
        # Deferred import. This is presumably to avoid an import cycle between
        # the sql storage modules (the annotation's import is TYPE_CHECKING-only)
        # -- TODO confirm.
        # pylint: disable=import-outside-toplevel
        from mlos_bench.storage.sql.tunable_config_trial_group_data import (
            TunableConfigTrialGroupSqlData,
        )

        group_data = TunableConfigTrialGroupSqlData(
            tunable_config_id=self._tunable_config_id,
            experiment_id=self._experiment_id,
            schema=self._schema,
            engine=self._engine,
        )
        return group_data

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Load this trial's results from storage.

        Returns
        -------
        pandas.DataFrame
            One row per stored result, with columns ``metric`` and ``value``,
            ordered by metric id.
        """
        tbl = self._schema.trial_result
        query = (
            tbl.select()
            .where(
                tbl.c.exp_id == self._experiment_id,
                tbl.c.trial_id == self._trial_id,
            )
            .order_by(tbl.c.metric_id)
        )
        with self._engine.connect() as conn:
            # fetchall() buffers every row, so the rows stay usable after the
            # connection is closed.
            rows = conn.execute(query).fetchall()
        records = [(row.metric_id, row.metric_value) for row in rows]
        return pandas.DataFrame(records, columns=["metric", "value"])

    @property
    def telemetry_df(self) -> pandas.DataFrame:
        """
        Load this trial's telemetry from storage.

        Returns
        -------
        pandas.DataFrame
            One row per stored telemetry sample, with columns ``ts``,
            ``metric`` and ``value``, ordered by timestamp and then by
            metric id.
        """
        tbl = self._schema.trial_telemetry
        query = (
            tbl.select()
            .where(
                tbl.c.exp_id == self._experiment_id,
                tbl.c.trial_id == self._trial_id,
            )
            .order_by(tbl.c.ts, tbl.c.metric_id)
        )
        with self._engine.connect() as conn:
            rows = conn.execute(query).fetchall()
        # Some storage backends drop the original time zone. Timestamps are
        # expected to be written in UTC, so the UTC zone info is reattached here.
        records = [
            (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
            for row in rows
        ]
        return pandas.DataFrame(records, columns=["ts", "metric", "value"])

    @property
    def metadata_df(self) -> pandas.DataFrame:
        """
        Load this trial's metadata parameters from storage.

        Note: this corresponds to the Trial object's "config" property.

        Returns
        -------
        pandas.DataFrame
            One row per stored parameter, with columns ``parameter`` and
            ``value``, ordered by parameter id.
        """
        tbl = self._schema.trial_param
        query = (
            tbl.select()
            .where(
                tbl.c.exp_id == self._experiment_id,
                tbl.c.trial_id == self._trial_id,
            )
            .order_by(tbl.c.param_id)
        )
        with self._engine.connect() as conn:
            rows = conn.execute(query).fetchall()
        records = [(row.param_id, row.param_value) for row in rows]
        return pandas.DataFrame(records, columns=["parameter", "value"])
148 )