Coverage for mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py: 100%
59 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for scheduling trials for some future time."""
6from collections.abc import Iterator
7from datetime import datetime, timedelta
9from pytz import UTC
11from mlos_bench.environments.status import Status
12from mlos_bench.storage.base_experiment_data import ExperimentData
13from mlos_bench.storage.base_storage import Storage
14from mlos_bench.tests.storage import (
15 CONFIG_COUNT,
16 CONFIG_TRIAL_REPEAT_COUNT,
17 TRIAL_RUNNER_COUNT,
18)
19from mlos_bench.tunables.tunable_groups import TunableGroups
def _trial_ids(trials: Iterator[Storage.Trial]) -> set[int]:
    """Collect the ``trial_id`` of every trial yielded by ``trials`` into a set.

    The input is iterated exactly once, so passing an iterator consumes it.
    """
    ids: set[int] = set()
    for trial in trials:
        ids.add(trial.trial_id)
    return ids
def test_schedule_trial(
    storage: Storage,
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
) -> None:
    # pylint: disable=too-many-locals
    """Schedule several trials for future execution and retrieve them later at certain
    timestamps.

    Exercises both sides of the experiment storage API:

    * scheduler side: ``pending_trials(ts, running=...)`` returns the trials whose
      scheduled time is at or before ``ts``. ``running`` controls whether trials
      that have already started are included. Completed trials are never returned.
    * optimizer side: ``load()`` returns only completed trials (succeeded or
      failed). It can be restricted to those after a given ``last_trial_id``.
    """
    timestamp = datetime.now(UTC)
    timedelta_1min = timedelta(minutes=1)
    timedelta_1hr = timedelta(hours=1)
    config = {"location": "westus2", "num_repeats": 10}

    # Default, schedule now:
    trial_now1 = exp_storage.new_trial(tunable_groups, config=config)
    # Schedule with explicit current timestamp:
    trial_now2 = exp_storage.new_trial(tunable_groups, timestamp, config)
    # Schedule 1 hour in the future:
    trial_1h = exp_storage.new_trial(tunable_groups, timestamp + timedelta_1hr, config)
    # Schedule 2 hours in the future:
    trial_2h = exp_storage.new_trial(tunable_groups, timestamp + timedelta_1hr * 2, config)

    exp_data = storage.experiments[exp_storage.experiment_id]
    trial_now1_data = exp_data.trials[trial_now1.trial_id]
    assert trial_now1_data.trial_runner_id is None
    assert trial_now1_data.status == Status.PENDING
    # Check that Status matches in object vs. backend storage.
    assert trial_now1.status == trial_now1_data.status

    # Scheduler side: get trials ready to run at certain timestamps:

    # Pretend 1 minute has passed, get trials scheduled to run:
    pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1min, running=False))
    assert pending_ids == {
        trial_now1.trial_id,
        trial_now2.trial_id,
    }

    # Get trials scheduled to run within the next 1 hour
    # (the scheduled time is inclusive, so trial_1h is returned):
    pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1hr, running=False))
    assert pending_ids == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
    }

    # Get trials scheduled to run within the next 3 hours:
    pending_ids = _trial_ids(
        exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)
    )
    assert pending_ids == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
        trial_2h.trial_id,
    }

    # Optimizer side: get trials completed after some known trial:

    # No completed trials yet:
    assert exp_storage.load() == ([], [], [], [])

    # Update the status of some trials:
    trial_now1.update(Status.RUNNING, timestamp + timedelta_1min)
    trial_now2.update(Status.RUNNING, timestamp + timedelta_1min)

    # Still no completed trials:
    assert exp_storage.load() == ([], [], [], [])

    # Trials that have started running are excluded when running=False,
    # leaving only the not-yet-started ones within the next 3 hours:
    pending_ids = _trial_ids(
        exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)
    )
    assert pending_ids == {
        trial_1h.trial_id,
        trial_2h.trial_id,
    }

    # Get trials scheduled to run OR running within the next 3 hours:
    pending_ids = _trial_ids(
        exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True)
    )
    assert pending_ids == {
        trial_now1.trial_id,
        trial_now2.trial_id,
        trial_1h.trial_id,
        trial_2h.trial_id,
    }

    # Mark some trials completed after 2 minutes:
    trial_now1.update(Status.SUCCEEDED, timestamp + timedelta_1min * 2, metrics={"score": 1.0})
    trial_now2.update(Status.FAILED, timestamp + timedelta_1min * 2)

    # Another one completes after 2 hours:
    trial_1h.update(Status.SUCCEEDED, timestamp + timedelta_1hr * 2, metrics={"score": 1.0})

    # Check that three trials have completed so far:
    (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load()
    assert trial_ids == [trial_now1.trial_id, trial_now2.trial_id, trial_1h.trial_id]
    assert len(trial_configs) == len(trial_scores) == 3
    assert trial_status == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED]

    # Get only trials completed after trial_now2:
    (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load(
        last_trial_id=trial_now2.trial_id
    )
    assert trial_ids == [trial_1h.trial_id]
    assert len(trial_configs) == len(trial_scores) == 1
    assert trial_status == [Status.SUCCEEDED]

    # trial_2h has not completed, so nothing is reported after trial_1h:
    assert exp_storage.load(last_trial_id=trial_1h.trial_id) == ([], [], [], [])

    # Completed trials (succeeded or failed) are no longer pending, regardless of
    # whether running trials are requested; only trial_2h is still to be run:
    pending_ids = _trial_ids(
        exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)
    )
    assert pending_ids == {trial_2h.trial_id}
    pending_ids = _trial_ids(
        exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True)
    )
    assert pending_ids == {trial_2h.trial_id}
def test_rr_scheduling(exp_data: ExperimentData) -> None:
    """Verify that trials were handed out to TrialRunners in round-robin order.

    Each tunable config is expected to occupy CONFIG_TRIAL_REPEAT_COUNT consecutive
    trial IDs (one per repeat), while consecutive trials cycle through the
    TRIAL_RUNNER_COUNT runners in turn.
    """
    # All IDs the user sees (trial, config, repeat, runner) are 1-based, so iterate
    # over a 0-based offset and shift each derived value back by one.
    for offset in range(CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT):
        trial_id = offset + 1
        config_offset, repeat_offset = divmod(offset, CONFIG_TRIAL_REPEAT_COUNT)
        expected_config_id = config_offset + 1
        expected_repeat_num = repeat_offset + 1
        expected_runner_id = offset % TRIAL_RUNNER_COUNT + 1

        trial = exp_data.trials[trial_id]

        assert trial.trial_id == trial_id, f"Expected trial_id {trial_id} for {trial}"

        config_ok = trial.tunable_config_id == expected_config_id
        assert config_ok, f"Expected tunable_config_id {expected_config_id} for {trial}"

        repeat_ok = trial.metadata_dict["repeat_i"] == expected_repeat_num
        assert repeat_ok, f"Expected repeat_i {expected_repeat_num} for {trial}"

        runner_ok = trial.trial_runner_id == expected_runner_id
        assert runner_ok, f"Expected trial_runner_id {expected_runner_id} for {trial}"