Coverage for mlos_bench/mlos_bench/tests/storage/trial_config_test.py: 100%
41 statements
coverage.py v7.6.9, created at 2024-12-14 01:58 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Unit tests for saving and retrieving additional parameters of pending trials."""
6from datetime import datetime
7from typing import Any, Dict
9import pytest
10from pytz import UTC
12from mlos_bench.storage.base_storage import Storage
13from mlos_bench.tunables.tunable_groups import TunableGroups
def test_exp_trial_pending(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
    """Check that a freshly scheduled trial is reported as pending with its full config.

    The extra (non-tunable) parameters come back as strings, and the returned
    config also carries the experiment and trial identifiers.
    """
    trial_params = {"location": "westus2", "num_repeats": 100}
    new_trial = exp_storage.new_trial(tunable_groups, config=trial_params)

    # Exactly one trial is expected; unpacking fails loudly otherwise.
    [pending] = list(exp_storage.pending_trials(datetime.now(UTC), running=True))

    assert pending.trial_id == new_trial.trial_id
    assert pending.tunables == tunable_groups

    expected_config = {
        "location": "westus2",
        "num_repeats": "100",  # the integer 100 is read back as a string
        "experiment_id": "Test-001",
        "trial_id": new_trial.trial_id,
    }
    assert pending.config() == expected_config
def test_exp_trial_configs(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None:
    """Check that trials with identical tunable values share one stored config.

    Two different tunable assignments are each scheduled three times: twice with
    the same object and once with an equal-valued copy. Exactly two distinct
    config IDs must end up among the pending trials.
    """

    def _schedule_three(idle_value: str) -> list:
        """Schedule three trials for one ``idle`` assignment (same instance x2, then a copy)."""
        assigned = tunable_groups.copy().assign({"idle": idle_value})
        return [
            exp_storage.new_trial(assigned),
            exp_storage.new_trial(assigned),
            exp_storage.new_trial(assigned.copy()),  # Same values, different instance
        ]

    mwait_trials = _schedule_three("mwait")
    assert len({t.tunable_config_id for t in mwait_trials}) == 1

    halt_trials = _schedule_three("halt")
    assert len({t.tunable_config_id for t in halt_trials}) == 1

    mwait_id = mwait_trials[0].tunable_config_id
    halt_id = halt_trials[0].tunable_config_id
    assert mwait_id != halt_id

    pending_ids = [
        trial.tunable_config_id
        for trial in exp_storage.pending_trials(datetime.now(UTC), running=True)
    ]
    assert len(pending_ids) == 6
    distinct_ids = set(pending_ids)
    assert len(distinct_ids) == 2
    assert distinct_ids == {mwait_id, halt_id}
def test_exp_trial_no_config(exp_no_tunables_storage: Storage.Experiment) -> None:
    """Check that a trial with no tunables and no extra parameters is still pending.

    With an empty config, the pending trial's config holds only the experiment
    and trial identifiers.
    """
    # The same empty dict is used both to build the tunables and as the trial config.
    no_params: dict = {}
    empty_tunables = TunableGroups(config=no_params)
    new_trial = exp_no_tunables_storage.new_trial(empty_tunables, config=no_params)

    pending_trials = exp_no_tunables_storage.pending_trials(datetime.now(UTC), running=True)
    (pending,) = pending_trials

    assert pending.trial_id == new_trial.trial_id
    assert pending.tunables == empty_tunables
    assert pending.config() == {"experiment_id": "Test-003", "trial_id": new_trial.trial_id}
@pytest.mark.parametrize(
    "bad_config",
    [
        {"obj": object()},
        {"callable": lambda x: x},
        {"nested": {"callable": lambda x: x}},
    ],
)
def test_exp_trial_non_serializable_config(
    exp_no_tunables_storage: Storage.Experiment,
    bad_config: Dict[str, Any],
) -> None:
    """Check that ``new_trial`` rejects a config that cannot be serialized.

    Each parametrized config holds a value (a bare ``object()`` or a lambda,
    possibly nested one level down) that is presumably not serializable by the
    storage layer. Scheduling it must raise ``ValueError``.
    """
    tunables = TunableGroups(config={})
    with pytest.raises(ValueError):
        exp_no_tunables_storage.new_trial(tunables, config=bad_config)