Coverage for mlos_bench/mlos_bench/storage/sql/trial_data.py: 100%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""An interface to access the benchmark trial data stored in SQL DB.""" 

6from datetime import datetime 

7from typing import TYPE_CHECKING, Optional 

8 

9import pandas 

10from sqlalchemy import Engine 

11 

12from mlos_bench.environments.status import Status 

13from mlos_bench.storage.base_trial_data import TrialData 

14from mlos_bench.storage.base_tunable_config_data import TunableConfigData 

15from mlos_bench.storage.sql.schema import DbSchema 

16from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData 

17from mlos_bench.util import utcify_timestamp 

18 

19if TYPE_CHECKING: 

20 from mlos_bench.storage.base_tunable_config_trial_group_data import ( 

21 TunableConfigTrialGroupData, 

22 ) 

23 

24 

class TrialSqlData(TrialData):
    """
    SQL-backed accessor for the stored data of a single benchmark trial.

    The object keeps a reference to the SQLAlchemy engine and the DB schema it
    was created with, and runs a fresh query against the relevant table each
    time one of its data frame properties is read.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        ts_start: datetime,
        ts_end: Optional[datetime],
        status: Status,
    ):
        """
        Create a new trial data accessor.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections for the queries.
        schema : DbSchema
            Schema object that provides the table definitions to query.
        experiment_id : str
            Experiment ID, forwarded to the base class.
        trial_id : int
            Trial ID, forwarded to the base class.
        config_id : int
            Forwarded to the base class as `tunable_config_id`.
        ts_start : datetime
            Trial start timestamp, forwarded to the base class.
        ts_end : Optional[datetime]
            Trial end timestamp (may be None), forwarded to the base class.
        status : Status
            Trial status, forwarded to the base class.
        """
        # All keyword arguments are mandatory (note the bare `*` above).
        super().__init__(
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            ts_start=ts_start,
            ts_end=ts_end,
            status=status,
        )
        self._engine = engine
        self._schema = schema

    @property
    def tunable_config(self) -> TunableConfigData:
        """
        Get the tunable configuration data object of this trial.

        Note: this corresponds to the Trial object's "tunables" property.

        Returns
        -------
        TunableConfigData
            A new `TunableConfigSqlData` bound to the same engine and schema.
        """
        config_data = TunableConfigSqlData(
            engine=self._engine,
            schema=self._schema,
            tunable_config_id=self._tunable_config_id,
        )
        return config_data

    @property
    def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData":
        """
        Get the tunable config trial group data object of this trial.

        Returns
        -------
        TunableConfigTrialGroupData
            A new `TunableConfigTrialGroupSqlData` for this trial's experiment
            and tunable config, bound to the same engine and schema.
        """
        # The SQL implementation is imported lazily, at call time.
        # pylint: disable=import-outside-toplevel
        from mlos_bench.storage.sql.tunable_config_trial_group_data import (
            TunableConfigTrialGroupSqlData,
        )

        group_data = TunableConfigTrialGroupSqlData(
            engine=self._engine,
            schema=self._schema,
            experiment_id=self._experiment_id,
            tunable_config_id=self._tunable_config_id,
        )
        return group_data

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Fetch the trial's results from the storage.

        Returns
        -------
        pandas.DataFrame
            One row per stored result with columns `metric` and `value`,
            ordered by metric ID.
        """
        table = self._schema.trial_result
        query = (
            table.select()
            .where(
                table.c.exp_id == self._experiment_id,
                table.c.trial_id == self._trial_id,
            )
            .order_by(table.c.metric_id)
        )
        with self._engine.connect() as conn:
            rows = conn.execute(query).fetchall()
        records = [(row.metric_id, row.metric_value) for row in rows]
        return pandas.DataFrame(records, columns=["metric", "value"])

    @property
    def telemetry_df(self) -> pandas.DataFrame:
        """
        Fetch the trial's telemetry from the storage.

        Returns
        -------
        pandas.DataFrame
            One row per telemetry record with columns `ts`, `metric`, and
            `value`, ordered by timestamp and then by metric ID.
        """
        table = self._schema.trial_telemetry
        query = (
            table.select()
            .where(
                table.c.exp_id == self._experiment_id,
                table.c.trial_id == self._trial_id,
            )
            .order_by(table.c.ts, table.c.metric_id)
        )
        with self._engine.connect() as conn:
            rows = conn.execute(query).fetchall()
        # Some storage backends do not keep the original time zone info.
        # Data is expected to be written in UTC, so the timestamps are
        # augmented with UTC again on the way out.
        records = [
            (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value)
            for row in rows
        ]
        return pandas.DataFrame(records, columns=["ts", "metric", "value"])

    @property
    def metadata_df(self) -> pandas.DataFrame:
        """
        Fetch the trial's metadata parameters from the storage.

        Note: this corresponds to the Trial object's "config" property.

        Returns
        -------
        pandas.DataFrame
            One row per stored parameter with columns `parameter` and `value`,
            ordered by parameter ID.
        """
        table = self._schema.trial_param
        query = (
            table.select()
            .where(
                table.c.exp_id == self._experiment_id,
                table.c.trial_id == self._trial_id,
            )
            .order_by(table.c.param_id)
        )
        with self._engine.connect() as conn:
            rows = conn.execute(query).fetchall()
        records = [(row.param_id, row.param_value) for row in rows]
        return pandas.DataFrame(records, columns=["parameter", "value"])