Coverage for mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py: 100%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#

5"""Unit tests for scheduling trials for some future time.""" 

6from collections.abc import Iterator 

7from datetime import datetime, timedelta 

8 

9from pytz import UTC 

10 

11from mlos_bench.environments.status import Status 

12from mlos_bench.storage.base_experiment_data import ExperimentData 

13from mlos_bench.storage.base_storage import Storage 

14from mlos_bench.tests.storage import ( 

15 CONFIG_COUNT, 

16 CONFIG_TRIAL_REPEAT_COUNT, 

17 TRIAL_RUNNER_COUNT, 

18) 

19from mlos_bench.tunables.tunable_groups import TunableGroups 

20 

21 

def _trial_ids(trials: Iterator[Storage.Trial]) -> set[int]:
    """Collect the `trial_id` of every trial yielded by the iterator into a set."""
    collected: set[int] = set()
    for trial in trials:
        collected.add(trial.trial_id)
    return collected

25 

26 

def test_schedule_trial(
    storage: Storage,
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
) -> None:
    # pylint: disable=too-many-locals
    """Register trials with different start times, then check which of them the storage
    reports as pending or completed as the clock and the trial statuses advance.
    """
    start = datetime.now(UTC)
    minute = timedelta(minutes=1)
    hour = timedelta(hours=1)
    trial_config = {"location": "westus2", "num_repeats": 10}

    def pending_at(when: datetime, running: bool) -> set[int]:
        """Return the IDs of the trials the storage lists as pending at `when`."""
        return _trial_ids(exp_storage.pending_trials(when, running=running))

    # Two trials due immediately: the first relies on the default start time,
    # the second passes the current timestamp explicitly.
    trial_now1 = exp_storage.new_trial(tunable_groups, config=trial_config)
    trial_now2 = exp_storage.new_trial(tunable_groups, start, trial_config)
    # Two more trials, due one and two hours from now respectively.
    trial_1h = exp_storage.new_trial(tunable_groups, start + hour, trial_config)
    trial_2h = exp_storage.new_trial(tunable_groups, start + hour * 2, trial_config)

    # A freshly created trial has no runner assigned and is PENDING, and the
    # in-memory object agrees with what is read back from the storage backend.
    stored_now1 = storage.experiments[exp_storage.experiment_id].trials[trial_now1.trial_id]
    assert stored_now1.trial_runner_id is None
    assert stored_now1.status == Status.PENDING
    assert trial_now1.status == stored_now1.status

    # Expected ID sets, named once and reused by the checks below.
    due_now = {trial_now1.trial_id, trial_now2.trial_id}
    due_later = {trial_1h.trial_id, trial_2h.trial_id}
    in_three_hours = start + hour * 3

    # Scheduler side: which trials are ready to run at a given timestamp?

    # One minute in: only the two immediate trials are due.
    assert pending_at(start + minute, running=False) == due_now

    # One hour in: the 1-hour trial becomes due as well.
    assert pending_at(start + hour, running=False) == due_now | {trial_1h.trial_id}

    # Three hours in: every trial is due.
    assert pending_at(in_three_hours, running=False) == due_now | due_later

    # Optimizer side: which trials have completed?

    # Nothing has completed yet.
    assert exp_storage.load() == ([], [], [], [])

    # Start the two immediate trials.
    for started in (trial_now1, trial_now2):
        started.update(Status.RUNNING, start + minute)

    # Running trials do not count as completed.
    assert exp_storage.load() == ([], [], [], [])

    # With running trials excluded, only the not-yet-started ones remain pending...
    assert pending_at(in_three_hours, running=False) == due_later

    # ...and with running trials included, all four show up again.
    assert pending_at(in_three_hours, running=True) == due_now | due_later

    # Two minutes in, one immediate trial succeeds and the other fails.
    trial_now1.update(Status.SUCCEEDED, start + minute * 2, metrics={"score": 1.0})
    trial_now2.update(Status.FAILED, start + minute * 2)

    # Two hours in, the 1-hour trial succeeds too.
    trial_1h.update(Status.SUCCEEDED, start + hour * 2, metrics={"score": 1.0})

    # All three finished trials are now reported, in order.
    done_ids, done_configs, done_scores, done_status = exp_storage.load()
    assert done_ids == [trial_now1.trial_id, trial_now2.trial_id, trial_1h.trial_id]
    assert len(done_configs) == len(done_scores) == 3
    assert done_status == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED]

    # Restricting the query to trials after trial_now2 leaves only the 1-hour trial.
    done_ids, done_configs, done_scores, done_status = exp_storage.load(
        last_trial_id=trial_now2.trial_id
    )
    assert done_ids == [trial_1h.trial_id]
    assert len(done_configs) == len(done_scores) == 1
    assert done_status == [Status.SUCCEEDED]

137 

138 

def test_rr_scheduling(exp_data: ExperimentData) -> None:
    """Verify that every trial carries the config ID, repeat number and TrialRunner ID
    expected from basic round-robin scheduling.
    """
    total_trials = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT
    for index in range(total_trials):
        # All user-visible IDs are 1-based, so compute on the 0-based index
        # and shift each result by one.
        trial_id = index + 1
        config_index, repeat_index = divmod(index, CONFIG_TRIAL_REPEAT_COUNT)
        expected_config_id = config_index + 1
        expected_repeat_num = repeat_index + 1
        expected_runner_id = index % TRIAL_RUNNER_COUNT + 1

        trial = exp_data.trials[trial_id]
        assert trial.trial_id == trial_id, f"Expected trial_id {trial_id} for {trial}"
        assert (
            trial.tunable_config_id == expected_config_id
        ), f"Expected tunable_config_id {expected_config_id} for {trial}"
        assert (
            trial.metadata_dict["repeat_i"] == expected_repeat_num
        ), f"Expected repeat_i {expected_repeat_num} for {trial}"
        assert (
            trial.trial_runner_id == expected_runner_id
        ), f"Expected trial_runner_id {expected_runner_id} for {trial}"

156 assert ( 

157 trial.trial_runner_id == expected_runner_id 

158 ), f"Expected trial_runner_id {expected_runner_id} for {trial}"