Coverage for mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-14 01:58 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for loading the TunableConfigTrialGroupData.""" 

6 

7from mlos_bench.storage.base_experiment_data import ExperimentData 

8from mlos_bench.tests.storage import CONFIG_TRIAL_REPEAT_COUNT 

9from mlos_bench.tunables.tunable_groups import TunableGroups 

10 

11 

def test_tunable_config_trial_group_data(exp_data: ExperimentData) -> None:
    """
    Check the basic properties of a trial's TunableConfigTrialGroupData.

    The group reached through a trial must agree with that trial (and with the
    experiment) on the experiment id, the tunable config id and the tunable
    config, and it must compare equal to the group reached through one of its
    own member trials.
    """
    trial = exp_data.trials[1]
    group = trial.tunable_config_trial_group

    # The experiment id must be consistent across the group, experiment and trial.
    assert group.experiment_id == exp_data.experiment_id == trial.experiment_id
    # The group shares the trial's tunable config (both the id and the config itself).
    assert group.tunable_config_id == trial.tunable_config_id
    assert group.tunable_config == trial.tunable_config

    # Round trip: going group -> member trial -> group should yield an equal group.
    some_member = next(iter(group.trials.values()))
    assert group == some_member.tunable_config_trial_group

26 

27 

def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) -> None:
    """
    Test the TunableConfigTrialGroupData property of TrialData.

    See Also:
    - test_exp_data_tunable_config_trial_group_id_in_results_df()
    - test_exp_data_tunable_config_trial_groups()

    This tests individual fetching.
    """
    # The first CONFIG_TRIAL_REPEAT_COUNT trials should all use the same (first)
    # config, and so all belong to trial group 1.
    # NOTE(review): assumes the fixture runs each config CONFIG_TRIAL_REPEAT_COUNT
    # times back-to-back, and that a group's id is the trial_id of its first trial.
    for trial_id in range(1, CONFIG_TRIAL_REPEAT_COUNT + 1):
        trial = exp_data.trials[trial_id]
        assert trial.tunable_config_id == 1
        assert trial.tunable_config_trial_group.tunable_config_trial_group_id == 1

    # The trial right after the first group should start a new config, and hence
    # a new group whose id is that trial's id.
    next_trial_id = CONFIG_TRIAL_REPEAT_COUNT + 1
    next_trial = exp_data.trials[next_trial_id]
    assert next_trial.tunable_config_id == 2
    assert next_trial.tunable_config_trial_group.tunable_config_trial_group_id == next_trial_id

    # And so on ...

53 

54 

def test_tunable_config_trial_group_results_df(
    exp_data: ExperimentData,
    tunable_groups: TunableGroups,
) -> None:
    """
    Check that a TunableConfigTrialGroup's results_df covers only its own config.

    Every row must belong to the selected tunable config and trial group, each
    trial must appear exactly once, and both an objective result column and a
    tunable config column must be present.
    """
    config_id = 2
    group_id = 4
    expected_rows = CONFIG_TRIAL_REPEAT_COUNT
    df = exp_data.tunable_config_trial_groups[config_id].results_df

    # Only the rows for this one config should be present -- none from other configs.
    assert len(df) == expected_rows

    same_config = df["tunable_config_id"] == config_id
    assert len(df[same_config]) == expected_rows
    other_config = df["tunable_config_id"] != config_id
    assert len(df[other_config]) == 0

    same_group = df["tunable_config_trial_group_id"] == group_id
    assert len(df[same_group]) == expected_rows
    other_group = df["tunable_config_trial_group_id"] != group_id
    assert len(df[other_group]) == 0

    # One row per trial.
    assert len(df["trial_id"].unique()) == expected_rows

    # The result and config columns must exist (column indexing raises otherwise).
    first_objective = next(iter(exp_data.objectives))
    result_column = ExperimentData.RESULT_COLUMN_PREFIX + first_objective
    assert len(df[result_column]) == expected_rows

    first_tunable, _ = next(iter(tunable_groups))
    config_column = ExperimentData.CONFIG_COLUMN_PREFIX + first_tunable.name
    assert len(df[config_column]) == expected_rows

81 

82 

def test_tunable_config_trial_group_trials(exp_data: ExperimentData) -> None:
    """
    Tests the trials property of the TunableConfigTrialGroup.

    All trials in the group must share its tunable config and report the same
    group id, and the trial whose id equals the group id must be a member that
    matches the trial fetched directly from the experiment.
    """
    tunable_config_id = 2
    expected_group_id = 4
    tunable_config_trial_group = exp_data.tunable_config_trial_groups[tunable_config_id]
    # Fetch the group's trials once and reuse the mapping below; re-reading the
    # ``trials`` property may re-query storage on every access.
    trials = tunable_config_trial_group.trials
    assert len(trials) == CONFIG_TRIAL_REPEAT_COUNT
    assert all(
        trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id
        for trial in trials.values()
    )
    assert all(trial.tunable_config_id == tunable_config_id for trial in trials.values())
    # Check membership first so a missing trial fails with a clear assertion
    # instead of a bare KeyError on the lookup below.
    assert expected_group_id in trials
    assert exp_data.trials[expected_group_id] == trials[expected_group_id]