Coverage for mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py: 100%

23 statements  

« prev | ^ index | » next — coverage.py v7.6.1, created at 2024-10-07 01:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for loading the TunableConfigData.""" 

6 

7from mlos_bench.storage.base_experiment_data import ExperimentData 

8from mlos_bench.tunables.tunable_groups import TunableGroups 

9 

10 

def test_trial_data_tunable_config_data(
    exp_data: ExperimentData,
    tunable_groups: TunableGroups,
) -> None:
    """Verify the TunableConfigData attached to trial 1.

    Trial 1 must reference tunable config id 1. That config's values must equal
    ``tunable_groups.get_param_values()``, and the trial's config-trial group
    must point back at the very same config object.
    """
    first_trial = exp_data.trials[1]
    config_data = first_trial.tunable_config
    assert config_data.tunable_config_id == 1
    # The first recorded config is expected to hold the default tunable values.
    assert config_data.config_dict == tunable_groups.get_param_values()
    trial_group = first_trial.tunable_config_trial_group
    assert trial_group.tunable_config == config_data

24 

25 

def test_trial_metadata(exp_data: ExperimentData) -> None:
    """Check the experiment's objectives and every trial's metadata dict.

    The experiment is expected to have a single objective, ``score``, to be
    minimized. Each trial's ``metadata_dict`` must record that same target and
    direction, plus the trial's own id under ``trial_number``.
    """
    assert exp_data.objectives == {"score": "min"}
    for trial_id, trial in exp_data.trials.items():
        assert trial.metadata_dict == {
            "opt_target_0": "score",
            "opt_direction_0": "min",
            "trial_number": trial_id,
        }

35 

36 

def test_trial_data_no_tunables_config_data(exp_no_tunables_data: ExperimentData) -> None:
    """Check that every trial of an experiment without tunables has an empty config."""
    empty_config: dict = {}
    # Only the trial objects are needed, so iterate values() instead of items().
    # NOTE(review): if the fixture ever yields no trials, this loop passes
    # vacuously — confirm the fixture always populates trials.
    for trial in exp_no_tunables_data.trials.values():
        assert trial.tunable_config.config_dict == empty_config

42 

43 

def test_mixed_numerics_exp_trial_data(
    mixed_numerics_exp_data: ExperimentData,
    mixed_numerics_tunable_groups: TunableGroups,
) -> None:
    """Check that loaded config values keep each tunable's declared dtype.

    The experiment mixes numeric tunable types. For the first trial yielded by
    the experiment's trials mapping, every tunable's stored value must be an
    instance of that tunable's ``dtype``.
    """
    # Any single trial is enough; take whichever one the mapping yields first.
    trials_iter = iter(mixed_numerics_exp_data.trials.values())
    loaded_values = next(trials_iter).tunable_config.config_dict
    for tunable, _ in mixed_numerics_tunable_groups:
        value = loaded_values[tunable.name]
        assert isinstance(value, tunable.dtype)