Coverage for mlos_bench/mlos_bench/storage/sql/trial_data.py: 100%

37 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""An interface to access the benchmark trial data stored in SQL DB using the 

6:py:class:`.TrialData` interface. 

7""" 

8from datetime import datetime 

9from typing import TYPE_CHECKING 

10 

11import pandas 

12from sqlalchemy.engine import Engine 

13 

14from mlos_bench.environments.status import Status 

15from mlos_bench.storage.base_trial_data import TrialData 

16from mlos_bench.storage.base_tunable_config_data import TunableConfigData 

17from mlos_bench.storage.sql.schema import DbSchema 

18from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData 

19from mlos_bench.util import utcify_timestamp 

20 

21if TYPE_CHECKING: 

22 from mlos_bench.storage.base_tunable_config_trial_group_data import ( 

23 TunableConfigTrialGroupData, 

24 ) 

25 

26 

class TrialSqlData(TrialData):
    """SQL-backed :py:class:`.TrialData`: details of one trial, loaded from the DB on demand."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        trial_id: int,
        config_id: int,
        ts_start: datetime,
        ts_end: datetime | None,
        status: Status,
        trial_runner_id: int | None = None,
    ):
        """
        Record the trial's identifying fields and keep the DB handles for later queries.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine; each property below opens its own connection from it.
        schema : DbSchema
            Table definitions the queries are built against.
        experiment_id : str
            ID of the experiment this trial belongs to.
        trial_id : int
            ID of the trial within that experiment.
        config_id : int
            ID of the tunable configuration used by the trial
            (forwarded to the base class as ``tunable_config_id``).
        ts_start : datetime
            Trial start timestamp.
        ts_end : datetime | None
            Trial end timestamp, or None.
        status : Status
            Trial status.
        trial_runner_id : int | None
            ID of the trial runner, if known.
        """
        super().__init__(
            experiment_id=experiment_id,
            trial_id=trial_id,
            tunable_config_id=config_id,
            ts_start=ts_start,
            ts_end=ts_end,
            status=status,
            trial_runner_id=trial_runner_id,
        )
        self._engine = engine
        self._schema = schema

    @property
    def tunable_config(self) -> TunableConfigData:
        """
        Load the tunable configuration this trial was run with.

        Note: this corresponds to the Trial object's "tunables" property.
        """
        config_id = self._tunable_config_id
        return TunableConfigSqlData(
            engine=self._engine,
            schema=self._schema,
            tunable_config_id=config_id,
        )

    @property
    def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData":
        """Load the group of trials in this experiment that share this trial's tunable
        config.
        """
        # Imported here rather than at module level to avoid a circular import
        # (the module-level reference is only under TYPE_CHECKING).
        # pylint: disable=import-outside-toplevel
        from mlos_bench.storage.sql.tunable_config_trial_group_data import (
            TunableConfigTrialGroupSqlData,
        )

        group_kwargs = {
            "engine": self._engine,
            "schema": self._schema,
            "experiment_id": self._experiment_id,
            "tunable_config_id": self._tunable_config_id,
        }
        return TunableConfigTrialGroupSqlData(**group_kwargs)

    @property
    def results_df(self) -> pandas.DataFrame:
        """Load this trial's results as a ``metric``/``value`` frame, sorted by metric."""
        tbl = self._schema.trial_result
        query = (
            tbl.select()
            .where(tbl.c.exp_id == self._experiment_id, tbl.c.trial_id == self._trial_id)
            .order_by(tbl.c.metric_id)
        )
        with self._engine.connect() as conn:
            rows = conn.execute(query).fetchall()
        return pandas.DataFrame(
            [(rec.metric_id, rec.metric_value) for rec in rows],
            columns=["metric", "value"],
        )

    @property
    def telemetry_df(self) -> pandas.DataFrame:
        """Load this trial's telemetry as a ``ts``/``metric``/``value`` frame, sorted by
        timestamp and then metric.
        """
        tbl = self._schema.trial_telemetry
        query = (
            tbl.select()
            .where(tbl.c.exp_id == self._experiment_id, tbl.c.trial_id == self._trial_id)
            .order_by(tbl.c.ts, tbl.c.metric_id)
        )
        with self._engine.connect() as conn:
            rows = conn.execute(query).fetchall()
        # Some storage backends drop the original time zone, so timestamps are
        # written as UTC and re-tagged as UTC here on the way out.
        records = [
            (utcify_timestamp(rec.ts, origin="utc"), rec.metric_id, rec.metric_value)
            for rec in rows
        ]
        return pandas.DataFrame(records, columns=["ts", "metric", "value"])

    @property
    def metadata_df(self) -> pandas.DataFrame:
        """
        Load this trial's metadata params as a ``parameter``/``value`` frame, sorted by
        parameter.

        Note: this corresponds to the Trial object's "config" property.
        """
        tbl = self._schema.trial_param
        query = (
            tbl.select()
            .where(tbl.c.exp_id == self._experiment_id, tbl.c.trial_id == self._trial_id)
            .order_by(tbl.c.param_id)
        )
        with self._engine.connect() as conn:
            rows = conn.execute(query).fetchall()
        return pandas.DataFrame(
            [(rec.param_id, rec.param_value) for rec in rows],
            columns=["parameter", "value"],
        )