Coverage for mlos_bench/mlos_bench/tests/storage/trial_config_test.py: 100%
56 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Unit tests for saving and retrieving additional parameters of pending trials."""
6from datetime import datetime
7from typing import Any
9import pytest
10from pytz import UTC
12from mlos_bench.storage.base_storage import Storage
13from mlos_bench.tunables.tunable_groups import TunableGroups
def test_exp_trial_pending(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
    """Create a trial with extra config and verify it is the single pending trial.

    The config read back from storage is expected to contain the original keys (the
    integer ``num_repeats`` comes back as the string ``"100"``) plus the experiment and
    trial identifiers.
    """
    trial_config = {"location": "westus2", "num_repeats": 100}
    new_trial = exp_storage.new_trial(tunable_groups, config=trial_config)

    now = datetime.now(UTC)
    # Tuple-unpacking enforces that exactly one pending trial is returned.
    (only_pending,) = exp_storage.pending_trials(now, running=True)

    expected_config = {
        "location": "westus2",
        "num_repeats": "100",
        "experiment_id": "Test-001",
        "trial_id": new_trial.trial_id,
    }
    assert only_pending.trial_id == new_trial.trial_id
    assert only_pending.tunables == tunable_groups
    assert only_pending.config() == expected_config
def test_add_new_trial_config_data(
    storage: Storage,
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
) -> None:
    """Verify that keys not yet present in a trial's config can be added after creation."""
    base_config = {"location": "westus2", "num_repeats": 100}
    extra_config = {"new_key": "new_value"}

    trial = exp_storage.new_trial(tunable_groups, config=base_config)
    trial.add_new_config_data(extra_config)

    # Read the trial back through the top-level storage object, not the experiment handle.
    experiment_data = storage.experiments[exp_storage.experiment_id]
    stored_trial = experiment_data.trials[trial.trial_id]
    assert stored_trial.metadata_dict == base_config | extra_config
def test_add_bad_new_trial_config_data(
    storage: Storage,
    exp_storage: Storage.Experiment,
    tunable_groups: TunableGroups,
) -> None:
    """Verify that overwriting an existing trial config key raises ``ValueError`` and
    leaves the stored config unchanged.
    """
    original_config = {"location": "westus2", "num_repeats": 100}
    trial = exp_storage.new_trial(tunable_groups, config=original_config)

    # "location" was already set when the trial was created.
    conflicting_config = {"location": "eastus2"}
    with pytest.raises(ValueError):
        trial.add_new_config_data(conflicting_config)

    stored_trial = storage.experiments[exp_storage.experiment_id].trials[trial.trial_id]
    assert stored_trial.metadata_dict == original_config
def test_exp_trial_configs(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
    """Schedule three trials for each of two distinct tunable assignments and check that
    trials with equal tunable values share a ``tunable_config_id``, while the two
    assignments get different ids.
    """

    def schedule_three(idle: str) -> list:
        # Two trials reuse the same TunableGroups instance; the third gets an equal copy.
        tunables = tunable_groups.copy().assign({"idle": idle})
        return [
            exp_storage.new_trial(tunables),
            exp_storage.new_trial(tunables),
            exp_storage.new_trial(tunables.copy()),
        ]

    mwait_ids = [trial.tunable_config_id for trial in schedule_three("mwait")]
    assert mwait_ids[0] == mwait_ids[1]
    assert mwait_ids[0] == mwait_ids[2]

    halt_ids = [trial.tunable_config_id for trial in schedule_three("halt")]
    assert halt_ids[0] == halt_ids[1]
    assert halt_ids[0] == halt_ids[2]

    assert mwait_ids[0] != halt_ids[0]

    now = datetime.now(UTC)
    pending_ids = [p.tunable_config_id for p in exp_storage.pending_trials(now, running=True)]
    assert len(pending_ids) == 6
    distinct_ids = set(pending_ids)
    assert len(distinct_ids) == 2
    assert distinct_ids == {mwait_ids[0], halt_ids[0]}
def test_exp_trial_no_config(exp_no_tunables_storage: Storage.Experiment) -> None:
    """Schedule a trial with no tunables and no extra config.

    Only the experiment and trial identifiers should appear in its config.
    """
    empty_config: dict = {}
    no_tunables = TunableGroups(config=empty_config)
    trial = exp_no_tunables_storage.new_trial(no_tunables, config=empty_config)

    now = datetime.now(UTC)
    # List-pattern unpacking enforces that exactly one pending trial is returned.
    [pending] = exp_no_tunables_storage.pending_trials(now, running=True)

    assert pending.trial_id == trial.trial_id
    assert pending.tunables == no_tunables
    assert pending.config() == {"experiment_id": "Test-003", "trial_id": trial.trial_id}
@pytest.mark.parametrize(
    "bad_config",
    [
        {"obj": object()},
        {"callable": lambda x: x},
        {"nested": {"callable": lambda x: x}},
    ],
)
def test_exp_trial_non_serializable_config(
    exp_no_tunables_storage: Storage.Experiment,
    bad_config: dict[str, Any],
) -> None:
    """Verify that ``new_trial`` raises ``ValueError`` for a non-serializable config.

    The cases cover a plain object, a callable, and a callable nested in a sub-dict.
    """
    no_tunables = TunableGroups(config={})
    with pytest.raises(ValueError):
        exp_no_tunables_storage.new_trial(no_tunables, config=bad_config)