Coverage for mlos_bench/mlos_bench/tests/storage/exp_data_test.py: 100%

61 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for loading the experiment metadata.""" 

6 

7from mlos_bench.storage.base_experiment_data import ExperimentData 

8from mlos_bench.storage.base_storage import Storage 

9from mlos_bench.tests.storage import CONFIG_COUNT, CONFIG_TRIAL_REPEAT_COUNT 

10from mlos_bench.tunables.tunable_groups import TunableGroups 

11 

12 

def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) -> None:
    """Look up the experiment in the (empty) storage and compare its metadata."""
    loaded_exp = storage.experiments[exp_storage.experiment_id]

    # Each (attribute on the loaded experiment, attribute on the stored one)
    # pair must compare equal.
    for loaded_attr, stored_attr in (
        ("experiment_id", "experiment_id"),
        ("description", "description"),
        ("objectives", "opt_targets"),
    ):
        assert getattr(loaded_exp, loaded_attr) == getattr(exp_storage, stored_attr)

19 

20 

def test_exp_data_root_env_config(
    exp_storage: Storage.Experiment,
    exp_data: ExperimentData,
) -> None:
    """Check that ExperimentData.root_env_config mirrors the stored experiment's values."""
    # pylint: disable=protected-access
    actual = exp_data.root_env_config
    # The property is compared against a (root env config, git repo, git commit)
    # triple read from the experiment's private attributes.
    expected = (
        exp_storage._root_env_config,
        exp_storage._git_repo,
        exp_storage._git_commit,
    )
    assert actual == expected

32 

33 

def test_exp_trial_data_objectives(
    storage: Storage,
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
) -> None:
    """Create trials with custom objective metadata and check what comes back."""
    experiment_id = exp_storage.experiment_id

    # A trial whose config carries both the target and the direction.
    new_style_meta = {
        "opt_target": "some-other-target",
        "opt_direction": "max",
    }
    trial_opt_new = exp_storage.new_trial(tunable_groups, config=dict(new_style_meta))
    assert trial_opt_new.config() == {
        "experiment_id": experiment_id,
        "trial_id": trial_opt_new.trial_id,
        **new_style_meta,
    }

    # A trial whose config deliberately omits "opt_direction".
    old_style_meta = {
        "opt_target": "back-compat",
    }
    trial_opt_old = exp_storage.new_trial(tunable_groups, config=dict(old_style_meta))
    assert trial_opt_old.config() == {
        "experiment_id": experiment_id,
        "trial_id": trial_opt_old.trial_id,
        **old_style_meta,
    }

    # The experiment-level objectives must still match the stored experiment.
    exp = storage.experiments[experiment_id]
    assert exp.objectives == exp_storage.opt_targets

    # The metadata read back for the first trial matches what was supplied.
    stored_trial_new = exp.trials[trial_opt_new.trial_id]
    assert stored_trial_new.metadata_dict == new_style_meta

76 

77 

def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None:
    """Check the row counts and the expected columns of ExperimentData.results_df."""
    frame = exp_data.results_df
    total_trials = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT

    # One row per trial, with unique trial ids and CONFIG_COUNT distinct configs.
    assert len(frame) == total_trials
    assert len(frame["tunable_config_id"].unique()) == CONFIG_COUNT
    assert len(frame["trial_id"].unique()) == total_trials

    # The column for the first objective has a value slot for every trial.
    first_objective = next(iter(exp_data.objectives))
    result_column = ExperimentData.RESULT_COLUMN_PREFIX + first_objective
    assert len(frame[result_column]) == total_trials

    # The same holds for the column of the first tunable parameter.
    first_tunable, _covariant_group = next(iter(tunable_groups))
    config_column = ExperimentData.CONFIG_COLUMN_PREFIX + first_tunable.name
    assert len(frame[config_column]) == total_trials

94 

95 

def test_exp_no_tunables_data_results_df(exp_no_tunables_data: ExperimentData) -> None:
    """Check ExperimentData.results_df for an experiment that has no tunables."""
    frame = exp_no_tunables_data.results_df
    total_trials = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT

    # One row per trial, each with its own trial id.
    assert len(frame) == total_trials
    assert len(frame["trial_id"].unique()) == total_trials

    # The column for the first objective has a value slot for every trial.
    first_objective = next(iter(exp_no_tunables_data.objectives))
    result_column = ExperimentData.RESULT_COLUMN_PREFIX + first_objective
    assert len(frame[result_column]) == total_trials

    # No column may carry the config prefix.
    has_config_prefix = frame.columns.str.startswith(ExperimentData.CONFIG_COLUMN_PREFIX)
    assert not has_config_prefix.any()

107 

108 

def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: ExperimentData) -> None:
    """
    Tests the tunable_config_trial_group_id property of ExperimentData.results_df.

    Every trial in the first repeat group is checked (not just the first two),
    and the trial ids are derived from CONFIG_TRIAL_REPEAT_COUNT rather than
    hard-coded.

    See Also: test_exp_trial_data_tunable_config_trial_group_id()
    """
    results_df = exp_data.results_df

    def check_trial(trial_id: int, config_id: int, group_id: int) -> None:
        """Assert that exactly one row exists for trial_id, with the expected ids."""
        trial_df = results_df.loc[(results_df["trial_id"] == trial_id)]
        assert len(trial_df) == 1
        assert trial_df["tunable_config_id"].iloc[0] == config_id
        assert trial_df["tunable_config_trial_group_id"].iloc[0] == group_id

    # The first CONFIG_TRIAL_REPEAT_COUNT trials should all use the same (first)
    # config, and share the group id of the first trial.
    for trial_id in range(1, CONFIG_TRIAL_REPEAT_COUNT + 1):
        check_trial(trial_id, config_id=1, group_id=1)

    # The trial right after that should be a new config, and it starts a new
    # group whose id is its own trial id.
    next_group_first_trial_id = CONFIG_TRIAL_REPEAT_COUNT + 1
    check_trial(next_group_first_trial_id, config_id=2, group_id=next_group_first_trial_id)

    # And so on ...

135 

136 

def test_exp_data_tunable_config_trial_groups(exp_data: ExperimentData) -> None:
    """
    Check the tunable_config_trial_groups property of ExperimentData.

    This exercises bulk loading of the tunable_config_trial_groups.
    """
    trial_groups = exp_data.tunable_config_trial_groups
    expected_config_ids = list(range(1, CONFIG_COUNT + 1))
    expected_group_ids = list(
        range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT)
    )

    # The mapping is keyed by config_id ...
    assert list(trial_groups.keys()) == expected_config_ids

    # ... and each value reports that same config id.
    actual_config_ids = [group.tunable_config_id for group in trial_groups.values()]
    assert actual_config_ids == expected_config_ids

    # The group id of each entry should be the minimum trial_id of that group.
    actual_group_ids = [group.tunable_config_trial_group_id for group in trial_groups.values()]
    assert actual_group_ids == expected_group_ids

155 

156 

def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None:
    """Check the tunable_configs property of ExperimentData."""
    configs = exp_data.tunable_configs
    expected_ids = list(range(1, CONFIG_COUNT + 1))

    # The mapping is keyed by config_id ...
    assert list(configs.keys()) == expected_ids

    # ... and each stored object reports the same id.
    object_ids = [config.tunable_config_id for config in configs.values()]
    assert object_ids == expected_ids

165 

166 

def test_exp_data_default_config_id(exp_data: ExperimentData) -> None:
    """Check that ExperimentData.default_tunable_config_id equals 1."""
    default_config_id = exp_data.default_tunable_config_id
    assert default_config_id == 1