Coverage for mlos_bench/mlos_bench/tests/storage/trial_config_test.py: 100%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for saving and retrieving additional parameters of pending trials.""" 

6from datetime import datetime 

7from typing import Any 

8 

9import pytest 

10from pytz import UTC 

11 

12from mlos_bench.storage.base_storage import Storage 

13from mlos_bench.tunables.tunable_groups import TunableGroups 

14 

15 

def test_exp_trial_pending(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
    """Verify that a freshly scheduled trial is reported as pending with its full config.

    The config read back from storage is expected to contain the original keys
    (with the numeric value returned as a string) plus the experiment and trial ids.
    """
    scheduled = exp_storage.new_trial(
        tunable_groups,
        config={"location": "westus2", "num_repeats": 100},
    )
    pending_trials = list(exp_storage.pending_trials(datetime.now(UTC), running=True))
    (pending,) = pending_trials  # exactly one trial must be pending

    assert pending.trial_id == scheduled.trial_id
    assert pending.tunables == tunable_groups

    expected_config = {
        "location": "westus2",
        "num_repeats": "100",  # 100 comes back from storage as the string "100"
        "experiment_id": "Test-001",  # presumably set by the exp_storage fixture
        "trial_id": scheduled.trial_id,
    }
    assert pending.config() == expected_config

29 

30 

def test_add_new_trial_config_data(
    storage: Storage,
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
) -> None:
    """Verify that new, non-conflicting keys can be appended to a trial's config."""
    base_config = {"location": "westus2", "num_repeats": 100}
    extra_config = {"new_key": "new_value"}

    trial = exp_storage.new_trial(tunable_groups, config=base_config)
    trial.add_new_config_data(extra_config)

    # Read the trial back through the top-level storage accessor rather than the
    # experiment handle that created it.
    experiment_data = storage.experiments[exp_storage.experiment_id]
    stored_trial = experiment_data.trials[trial.trial_id]
    assert stored_trial.metadata_dict == base_config | extra_config

46 

47 

def test_add_bad_new_trial_config_data(
    storage: Storage,
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
) -> None:
    """Create a trial and check that re-adding an existing config key is rejected.

    The failed update must leave the stored metadata unchanged.
    """
    original_config = {"location": "westus2", "num_repeats": 100}
    trial = exp_storage.new_trial(tunable_groups, config=original_config)

    # "location" is already part of the trial config, so adding it again must fail.
    conflicting = {"location": "eastus2"}
    with pytest.raises(ValueError):
        trial.add_new_config_data(conflicting)

    stored = storage.experiments[exp_storage.experiment_id].trials[trial.trial_id]
    assert stored.metadata_dict == original_config

63 

64 

def test_exp_trial_configs(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
    """Schedule several trials using two distinct tunable configs and check that
    only two config objects are stored in the DB.
    """

    def _schedule_three(tunables: TunableGroups) -> list:
        # Two trials share one TunableGroups instance; the third gets an
        # equal-valued copy (same values, different instance).
        return [
            exp_storage.new_trial(tunables),
            exp_storage.new_trial(tunables),
            exp_storage.new_trial(tunables.copy()),
        ]

    mwait_trials = _schedule_three(tunable_groups.copy().assign({"idle": "mwait"}))
    mwait_ids = {t.tunable_config_id for t in mwait_trials}
    assert len(mwait_ids) == 1  # all three trials reuse one config id

    halt_trials = _schedule_three(tunable_groups.copy().assign({"idle": "halt"}))
    halt_ids = {t.tunable_config_id for t in halt_trials}
    assert len(halt_ids) == 1

    # Different tunable values must map to different config ids.
    assert mwait_ids != halt_ids

    pending_ids = [
        p.tunable_config_id
        for p in exp_storage.pending_trials(datetime.now(UTC), running=True)
    ]
    assert len(pending_ids) == 6
    assert len(set(pending_ids)) == 2
    assert set(pending_ids) == mwait_ids | halt_ids

96 

97 

def test_exp_trial_no_config(exp_no_tunables_storage: Storage.Experiment) -> None:
    """Schedule a trial whose tunable groups config is empty and check it is pending."""
    no_params: dict = {}
    empty_tunables = TunableGroups(config=no_params)
    trial = exp_no_tunables_storage.new_trial(empty_tunables, config=no_params)

    now = datetime.now(UTC)
    (pending,) = list(exp_no_tunables_storage.pending_trials(now, running=True))

    assert pending.trial_id == trial.trial_id
    assert pending.tunables == empty_tunables
    # With no tunables and no extra config, only the identifiers are present.
    assert pending.config() == {"experiment_id": "Test-003", "trial_id": trial.trial_id}

110 

111 

@pytest.mark.parametrize(
    "bad_config",
    [
        # A plain object instance.
        {"obj": object()},
        # A callable at the top level...
        {"callable": lambda x: x},
        # ...and a callable nested inside a sub-dict.
        {"nested": {"callable": lambda x: x}},
    ],
)
def test_exp_trial_non_serializable_config(
    exp_no_tunables_storage: Storage.Experiment,
    bad_config: dict[str, Any],
) -> None:
    """Check that scheduling a trial whose config cannot be serialized raises ValueError."""
    tunables = TunableGroups(config={})
    with pytest.raises(ValueError):
        exp_no_tunables_storage.new_trial(tunables, config=bad_config)