Coverage for mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py: 100%

28 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""An interface to access the tunable config trial group data stored in SQL DB.""" 

6 

7from typing import TYPE_CHECKING, Dict, Optional 

8 

9import pandas 

10from sqlalchemy import Engine, Integer, func 

11 

12from mlos_bench.storage.base_tunable_config_data import TunableConfigData 

13from mlos_bench.storage.base_tunable_config_trial_group_data import ( 

14 TunableConfigTrialGroupData, 

15) 

16from mlos_bench.storage.sql import common 

17from mlos_bench.storage.sql.schema import DbSchema 

18from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData 

19 

20if TYPE_CHECKING: 

21 from mlos_bench.storage.base_trial_data import TrialData 

22 

23 

class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData):
    """
    SQL-backed accessor for the trial group of one (tunable) config in an experiment.

    A (tunable) config is one assignment of values to the tunable parameters of an
    experiment. Several trials may run the same config (e.g., repeats); that set of
    trials is what we call a (tunable) config trial group.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        engine: Engine,
        schema: DbSchema,
        experiment_id: str,
        tunable_config_id: int,
        tunable_config_trial_group_id: Optional[int] = None,
    ):
        """
        Create the accessor; all arguments are keyword-only.

        Parameters
        ----------
        engine : Engine
            SQLAlchemy engine used to open connections to the storage DB.
        schema : DbSchema
            Schema object exposing the DB tables (e.g., ``schema.trial``).
        experiment_id : str
            ID of the experiment; forwarded to the base class.
        tunable_config_id : int
            ID of the tunable config; forwarded to the base class.
        tunable_config_trial_group_id : Optional[int]
            ID of the trial group, if already known; forwarded to the base class.
        """
        # The base class receives the identifiers; this class only adds the DB handles.
        super().__init__(
            experiment_id=experiment_id,
            tunable_config_id=tunable_config_id,
            tunable_config_trial_group_id=tunable_config_trial_group_id,
        )
        self._engine = engine
        self._schema = schema

    def _get_tunable_config_trial_group_id(self) -> int:
        """
        Look up this config's tunable_config_trial_group_id in the storage.

        The ID is computed as the smallest ``trial_id`` among the trials that match
        this experiment ID and tunable config ID.

        Returns
        -------
        tunable_config_trial_group_id : int
            The first (and only) column of the row returned by the query.
        """
        with self._engine.connect() as conn:
            trial_table = self._schema.trial
            group_id_column = (
                func.min(trial_table.c.trial_id)  # pylint: disable=not-callable
                .cast(Integer)
                .label("tunable_config_trial_group_id")
            )
            query = (
                trial_table.select()
                .with_only_columns(group_id_column)
                .where(
                    trial_table.c.exp_id == self._experiment_id,
                    trial_table.c.config_id == self._tunable_config_id,
                )
                .group_by(
                    trial_table.c.exp_id,
                    trial_table.c.config_id,
                )
            )
            row = conn.execute(query).fetchone()
            # No row means no trial matched the experiment/config pair.
            assert row is not None
            # pylint: disable=protected-access  # following DeprecationWarning in sqlalchemy
            values = row._tuple()
            return values[0]

    @property
    def tunable_config(self) -> TunableConfigData:
        """
        Build an accessor for the tunable config of this trial group.

        Returns
        -------
        tunable_config : TunableConfigData
            A new TunableConfigSqlData sharing this object's engine and schema.
        """
        config_data = TunableConfigSqlData(
            engine=self._engine,
            schema=self._schema,
            tunable_config_id=self.tunable_config_id,
        )
        return config_data

    @property
    def trials(self) -> Dict[int, "TrialData"]:
        """
        Fetch the data of the trials in this (tunable) config trial group from the
        storage.

        Returns
        -------
        trials : Dict[int, TrialData]
            A dictionary of the trials' data, keyed by trial id.
        """
        trials_by_id = common.get_trials(
            self._engine,
            self._schema,
            self._experiment_id,
            self._tunable_config_id,
        )
        return trials_by_id

    @property
    def results_df(self) -> pandas.DataFrame:
        """
        Fetch the results of this (tunable) config trial group from the storage.

        Returns
        -------
        results : pandas.DataFrame
            The frame produced by ``common.get_results_df`` for this experiment ID
            and tunable config ID.
        """
        results = common.get_results_df(
            self._engine,
            self._schema,
            self._experiment_id,
            self._tunable_config_id,
        )
        return results

108 )