Coverage for mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py: 100%
23 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-10-07 01:52 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for loading the TunableConfigData."""
7from mlos_bench.storage.base_experiment_data import ExperimentData
8from mlos_bench.tunables.tunable_groups import TunableGroups
def test_trial_data_tunable_config_data(
    exp_data: ExperimentData,
    tunable_groups: TunableGroups,
) -> None:
    """Verify the TunableConfigData attached to the first stored trial.

    Looks up trial 1 in ``exp_data`` and checks its tunable config ID, its
    parameter values, and the link back from the trial's config trial group.
    """
    first_trial_id = 1
    first_config_id = 1
    first_trial = exp_data.trials[first_trial_id]
    config_data = first_trial.tunable_config
    assert config_data.tunable_config_id == first_config_id
    # The first trial is expected to carry the default tunable values.
    default_values = tunable_groups.get_param_values()
    assert config_data.config_dict == default_values
    # The config trial group of the trial must refer to the very same config.
    assert first_trial.tunable_config_trial_group.tunable_config == config_data
def test_trial_metadata(exp_data: ExperimentData) -> None:
    """Verify the experiment objectives and the metadata dict of every trial.

    Each trial's ``metadata_dict`` must hold the single optimization target
    ("score", minimized) and a ``trial_number`` equal to the key under which
    the trial is stored in ``exp_data.trials``.
    """
    assert exp_data.objectives == {"score": "min"}
    for trial_number, trial_data in exp_data.trials.items():
        expected_metadata = {
            "opt_target_0": "score",
            "opt_direction_0": "min",
            "trial_number": trial_number,
        }
        assert trial_data.metadata_dict == expected_metadata
def test_trial_data_no_tunables_config_data(exp_no_tunables_data: ExperimentData) -> None:
    """Check that every trial in ``exp_no_tunables_data`` has an empty config dict.

    Only the trial objects are inspected, so the loop walks the values of the
    ``trials`` mapping directly instead of unpacking and discarding the keys.
    """
    empty_config: dict = {}
    for trial in exp_no_tunables_data.trials.values():
        assert trial.tunable_config.config_dict == empty_config
def test_mixed_numerics_exp_trial_data(
    mixed_numerics_exp_data: ExperimentData,
    mixed_numerics_tunable_groups: TunableGroups,
) -> None:
    """Tests that data type conversions are retained when loading experiment data
    with mixed numeric tunable types.

    Takes the first trial yielded by ``mixed_numerics_exp_data.trials`` and checks
    that each tunable's loaded value is an instance of that tunable's ``dtype``.
    """
    stored_trials = mixed_numerics_exp_data.trials.values()
    first_trial = next(iter(stored_trials))
    loaded_values = first_trial.tunable_config.config_dict
    # Iterating the tunable groups yields pairs; only the tunable is needed here.
    for tunable, _owner_group in mixed_numerics_tunable_groups:
        loaded_value = loaded_values[tunable.name]
        assert isinstance(loaded_value, tunable.dtype)