Coverage for mlos_bench/mlos_bench/tests/storage/sql/fixtures.py: 100%
54 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Test fixtures for mlos_bench storage."""
import json
from collections.abc import Generator
from random import seed as rand_seed

import pytest

from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.schedulers.sync_scheduler import SyncScheduler
from mlos_bench.schedulers.trial_runner import TrialRunner
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.storage.sql.storage import SqlStorage
from mlos_bench.tests import SEED
from mlos_bench.tests.storage import (
    CONFIG_TRIAL_REPEAT_COUNT,
    MAX_TRIALS,
    TRIAL_RUNNER_COUNT,
)
from mlos_bench.tunables.tunable_groups import TunableGroups
26# pylint: disable=redefined-outer-name
@pytest.fixture
def storage() -> SqlStorage:
    """Build a SqlStorage instance backed by an in-memory SQLite3 database."""
    sqlite_config = {
        "drivername": "sqlite",
        "database": ":memory:",
        # Alternative file-backed database (presumably for local debugging):
        # "database": "mlos_bench.pytest.db",
    }
    return SqlStorage(service=None, config=sqlite_config)
@pytest.fixture
def exp_storage(
    storage: SqlStorage,
    tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment]:
    """
    Yield the "Test-001" Experiment stored in the in-memory SQLite3 storage.

    Note: the experiment's context has already been entered when the test
    receives it; it is exited during fixture teardown.
    """
    experiment_ctx = storage.experiment(
        experiment_id="Test-001",
        trial_id=1,
        root_env_config="environment.jsonc",
        description="pytest experiment",
        tunables=tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment_ctx as exp:
        yield exp
    # After leaving the with-block the experiment must no longer be in context.
    # pylint: disable=protected-access
    assert not exp._in_context
@pytest.fixture
def exp_no_tunables_storage(
    storage: SqlStorage,
) -> Generator[SqlStorage.Experiment]:
    """
    Yield the "Test-003" Experiment, which has an empty TunableGroups, stored in
    the in-memory SQLite3 storage.

    Note: the experiment's context has already been entered when the test
    receives it; it is exited during fixture teardown.
    """
    experiment_ctx = storage.experiment(
        experiment_id="Test-003",
        trial_id=1,
        root_env_config="environment.jsonc",
        description="pytest experiment - no tunables",
        tunables=TunableGroups({}),
        opt_targets={"score": "min"},
    )
    with experiment_ctx as exp:
        yield exp
    # After leaving the with-block the experiment must no longer be in context.
    # pylint: disable=protected-access
    assert not exp._in_context
@pytest.fixture
def mixed_numerics_exp_storage(
    storage: SqlStorage,
    mixed_numerics_tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment]:
    """
    Yield the "Test-002" Experiment, built from the mixed numerics tunables,
    stored in the in-memory SQLite3 storage.

    Note: the experiment's context has already been entered when the test
    receives it; it is exited during fixture teardown.
    """
    experiment_ctx = storage.experiment(
        experiment_id="Test-002",
        trial_id=1,
        root_env_config="dne.jsonc",
        description="pytest experiment",
        tunables=mixed_numerics_tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment_ctx as exp:
        yield exp
    # After leaving the with-block the experiment must no longer be in context.
    # pylint: disable=protected-access
    assert not exp._in_context
def _dummy_run_exp(
    storage: SqlStorage,
    exp: SqlStorage.Experiment,
) -> ExperimentData:
    """
    Generates data by doing a simulated run of the given experiment.

    Parameters
    ----------
    storage : SqlStorage
        The storage object to use.
    exp : SqlStorage.Experiment
        The experiment to "run".
        Note: this particular object won't be updated, but a new one will be created
        from its metadata.

    Returns
    -------
    ExperimentData
        The data generated by the simulated run.
    """
    # pylint: disable=too-many-locals

    # Reseed the global `random` module (presumably so the simulated run is
    # reproducible across invocations).
    rand_seed(SEED)

    global_config: dict = {}
    config_loader = ConfigPersistenceService()

    # Describe the mock environment as a plain dict and serialize it with
    # json.dumps. Hand-writing the JSON inside an f-string would require every
    # literal brace to be doubled, and tunable names would need manual escaping.
    mock_env_config = {
        "class": "mlos_bench.environments.mock_env.MockEnv",
        "name": "Test Env",
        "config": {
            "tunable_params": list(exp.tunables.get_covariant_group_names()),
            "mock_env_seed": SEED,
            "mock_env_range": [60, 120],
            "mock_env_metrics": ["score"],
        },
    }
    trial_runners: list[TrialRunner] = TrialRunner.create_from_json(
        config_loader=config_loader,
        global_config=global_config,
        tunable_groups=exp.tunables,
        env_json=json.dumps(mock_env_config),
        svcs_json=None,
        num_trial_runners=TRIAL_RUNNER_COUNT,
    )

    opt = MockOptimizer(
        tunables=exp.tunables,
        config={
            "optimization_targets": exp.opt_targets,
            "seed": SEED,
            # This should be the default, so we leave it omitted for now to test the default.
            # But the test logic relies on this (e.g., trial 1 is config 1 is the
            # default values for the tunable params)
            # "start_with_defaults": True,
            "max_suggestions": MAX_TRIALS,
        },
        global_config=global_config,
    )

    scheduler = SyncScheduler(
        # All config values can be overridden from global config
        config={
            "experiment_id": exp.experiment_id,
            "trial_id": exp.trial_id,
            "config_id": -1,
            "trial_config_repeat_count": CONFIG_TRIAL_REPEAT_COUNT,
            "max_trials": MAX_TRIALS,
        },
        global_config=global_config,
        trial_runners=trial_runners,
        optimizer=opt,
        storage=storage,
        root_env_config=exp.root_env_config,
    )

    # Add some trial data to that experiment by "running" it.
    with scheduler:
        scheduler.start()
        scheduler.teardown()

    # Look the experiment up again through the storage to return its data view.
    return storage.experiments[exp.experiment_id]
@pytest.fixture
def exp_data(
    storage: SqlStorage,
    exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Test fixture for ExperimentData produced by a simulated run of `exp_storage`."""
    run_data = _dummy_run_exp(storage, exp_storage)
    return run_data
@pytest.fixture
def exp_no_tunables_data(
    storage: SqlStorage,
    exp_no_tunables_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Test fixture for ExperimentData from a simulated run of an experiment without tunables."""
    run_data = _dummy_run_exp(storage, exp_no_tunables_storage)
    return run_data
@pytest.fixture
def mixed_numerics_exp_data(
    storage: SqlStorage,
    mixed_numerics_exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Test fixture for ExperimentData from a simulated run with mixed numerical tunable types."""
    run_data = _dummy_run_exp(storage, mixed_numerics_exp_storage)
    return run_data