Coverage for mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py: 100%
46 statements
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
« prev ^ index » next coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for loading the TunableConfigTrialGroupData."""
7from mlos_bench.storage.base_experiment_data import ExperimentData
8from mlos_bench.tests.storage import CONFIG_TRIAL_REPEAT_COUNT
9from mlos_bench.tunables.tunable_groups import TunableGroups
def test_tunable_config_trial_group_data(exp_data: ExperimentData) -> None:
    """Check that a trial's TunableConfigTrialGroupData agrees with the trial itself."""
    trial = exp_data.trials[1]
    group = trial.tunable_config_trial_group

    # Group, experiment, and trial must all report the same experiment id.
    assert group.experiment_id == exp_data.experiment_id == trial.experiment_id
    # The group is built around the same tunable config as the trial it came from.
    assert group.tunable_config_id == trial.tunable_config_id
    assert group.tunable_config == trial.tunable_config

    # Round trip: a member trial of the group must point back at an equal group.
    first_member = next(iter(group.trials.values()))
    assert group == first_member.tunable_config_trial_group
def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) -> None:
    """
    Test the TunableConfigTrialGroupData property of TrialData.

    See Also:
    - test_exp_data_tunable_config_trial_group_id_in_results_df()
    - test_exp_data_tunable_config_trial_groups()

    This tests individual fetching.
    """
    # The first three trials should all use the same config (tunable_config_id 1),
    # and therefore all belong to trial group 1.
    # Check every one of them, not just a subset.
    for trial_id in (1, 2, 3):
        trial = exp_data.trials[trial_id]
        assert trial.tunable_config_id == 1, f"trial {trial_id}"
        group_id = trial.tunable_config_trial_group.tunable_config_trial_group_id
        assert group_id == 1, f"trial {trial_id}"

    # The fourth should be a new config (tunable_config_id 2) in a new group (id 4).
    trial_4 = exp_data.trials[4]
    assert trial_4.tunable_config_id == 2
    assert trial_4.tunable_config_trial_group.tunable_config_trial_group_id == 4

    # And so on ...
def test_tunable_config_trial_group_results_df(
    exp_data: ExperimentData,
    tunable_groups: TunableGroups,
) -> None:
    """Tests the results_df property of the TunableConfigTrialGroup."""
    config_id = 2
    group_id = 4
    count = CONFIG_TRIAL_REPEAT_COUNT

    df = exp_data.tunable_config_trial_groups[config_id].results_df

    # Only rows for this one config may appear -- nothing from any other config.
    assert len(df) == count
    config_ids = df["tunable_config_id"]
    assert len(df[config_ids == config_id]) == count
    assert len(df[config_ids != config_id]) == 0

    # Every row must also carry the expected trial group id.
    group_ids = df["tunable_config_trial_group_id"]
    assert len(df[group_ids == group_id]) == count
    assert len(df[group_ids != group_id]) == 0

    # Each row corresponds to a distinct trial.
    assert len(df["trial_id"].unique()) == count

    # An objective (result) column and a tunable (config) column are present.
    objective_name = next(iter(exp_data.objectives))
    assert len(df[ExperimentData.RESULT_COLUMN_PREFIX + objective_name]) == count
    first_tunable, _ = next(iter(tunable_groups))
    assert len(df[ExperimentData.CONFIG_COLUMN_PREFIX + first_tunable.name]) == count
def test_tunable_config_trial_group_trials(exp_data: ExperimentData) -> None:
    """
    Tests the trials property of the TunableConfigTrialGroup.

    Every trial in the group must report the group's tunable config id and trial
    group id, and must compare equal to the trial returned by the experiment-level
    ``exp_data.trials`` lookup for the same trial id.
    """
    tunable_config_id = 2
    expected_group_id = 4
    tunable_config_trial_group = exp_data.tunable_config_trial_groups[tunable_config_id]
    # Read the attribute once and reuse it for every check below.
    trials = tunable_config_trial_group.trials
    assert len(trials) == CONFIG_TRIAL_REPEAT_COUNT
    assert all(
        trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id
        for trial in trials.values()
    )
    assert all(trial.tunable_config_id == tunable_config_id for trial in trials.values())
    # expected_group_id is also the id of a member trial of this group.
    assert exp_data.trials[expected_group_id] == trials[expected_group_id]
    # Every member, not just that one, must be keyed by its own trial id and
    # match the experiment-wide lookup.
    for trial_id, trial in trials.items():
        assert trial.trial_id == trial_id
        assert exp_data.trials[trial_id] == trial