Coverage for mlos_bench/mlos_bench/tests/storage/sql/fixtures.py: 100%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Test fixtures for mlos_bench storage.""" 

6 

import json
from collections.abc import Generator
from random import seed as rand_seed

import pytest

from mlos_bench.optimizers.mock_optimizer import MockOptimizer
from mlos_bench.schedulers.sync_scheduler import SyncScheduler
from mlos_bench.schedulers.trial_runner import TrialRunner
from mlos_bench.services.config_persistence import ConfigPersistenceService
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.storage.sql.storage import SqlStorage
from mlos_bench.tests import SEED
from mlos_bench.tests.storage import (
    CONFIG_TRIAL_REPEAT_COUNT,
    MAX_TRIALS,
    TRIAL_RUNNER_COUNT,
)
from mlos_bench.tunables.tunable_groups import TunableGroups

25 

26# pylint: disable=redefined-outer-name 

27 

28 

@pytest.fixture
def storage() -> SqlStorage:
    """Provide a fresh SqlStorage backed by an in-memory SQLite3 database."""
    sqlite_config = {
        "drivername": "sqlite",
        "database": ":memory:",
        # To inspect results on disk instead, swap in:
        # "database": "mlos_bench.pytest.db",
    }
    return SqlStorage(service=None, config=sqlite_config)

40 

41 

@pytest.fixture
def exp_storage(
    storage: SqlStorage,
    tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment]:
    """
    Yield the "Test-001" Experiment, built on the in-memory SQLite3 storage from
    the shared ``tunable_groups`` fixture.

    Note: the experiment's context has already been entered when the test receives
    it; it is exited when the fixture resumes after the test.
    """
    experiment = storage.experiment(
        experiment_id="Test-001",
        trial_id=1,
        root_env_config="environment.jsonc",
        description="pytest experiment",
        tunables=tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment as exp:
        yield exp
    # Leaving the `with` block must have closed the experiment's context.
    # pylint: disable=protected-access
    assert not exp._in_context

63 

64 

@pytest.fixture
def exp_no_tunables_storage(
    storage: SqlStorage,
) -> Generator[SqlStorage.Experiment]:
    """
    Test fixture for an Experiment with *no* tunables using in-memory SQLite3
    storage.

    The experiment ("Test-003") is created with an empty TunableGroups, so tests
    can check behavior when there are no tunable parameters at all.

    Note: It has already entered the context upon return.
    """
    # Explicitly typed so the empty literal is accepted as a config dict.
    empty_config: dict = {}
    with storage.experiment(
        experiment_id="Test-003",
        trial_id=1,
        root_env_config="environment.jsonc",
        description="pytest experiment - no tunables",
        tunables=TunableGroups(empty_config),
        opt_targets={"score": "min"},
    ) as exp:
        yield exp
    # Leaving the `with` block must have closed the experiment's context.
    # pylint: disable=protected-access
    assert not exp._in_context

86 

87 

@pytest.fixture
def mixed_numerics_exp_storage(
    storage: SqlStorage,
    mixed_numerics_tunable_groups: TunableGroups,
) -> Generator[SqlStorage.Experiment]:
    """
    Yield the "Test-002" Experiment, whose tunables mix numeric types, on the
    in-memory SQLite3 storage.

    Note: the experiment's context has already been entered when the test receives
    it; it is exited when the fixture resumes after the test.
    """
    experiment = storage.experiment(
        experiment_id="Test-002",
        trial_id=1,
        description="pytest experiment",
        # NOTE(review): "dne" presumably means "does not exist" - confirm intent.
        root_env_config="dne.jsonc",
        tunables=mixed_numerics_tunable_groups,
        opt_targets={"score": "min"},
    )
    with experiment as exp:
        yield exp
    # Leaving the `with` block must have closed the experiment's context.
    # pylint: disable=protected-access
    assert not exp._in_context

110 

111 

def _dummy_run_exp(
    storage: SqlStorage,
    exp: SqlStorage.Experiment,
) -> ExperimentData:
    """
    Generates data by doing a simulated run of the given experiment.

    A MockEnv-based set of ``TRIAL_RUNNER_COUNT`` trial runners, a MockOptimizer and
    a SyncScheduler are wired together from ``exp``'s metadata and executed against
    ``storage``; the resulting experiment data is then read back from ``storage``.

    Parameters
    ----------
    storage : SqlStorage
        The storage object to use.
    exp : SqlStorage.Experiment
        The experiment to "run".
        Note: this particular object won't be updated, but a new one will be created
        from its metadata.

    Returns
    -------
    ExperimentData
        The data generated by the simulated run.
    """
    # pylint: disable=too-many-locals

    # Seed the stdlib global RNG so the simulated run is repeatable.
    rand_seed(SEED)

    # A single (empty) global config dict is shared by every component below.
    global_config: dict = {}
    config_loader = ConfigPersistenceService()

    # Build the MockEnv config as a Python dict and serialize it with json.dumps
    # rather than splicing values into a JSON text template: this guarantees the
    # tunable group names are quoted/escaped correctly (e.g. names containing
    # quotes or backslashes) and that the result is always valid JSON.
    mock_env_config = {
        "class": "mlos_bench.environments.mock_env.MockEnv",
        "name": "Test Env",
        "config": {
            "tunable_params": list(exp.tunables.get_covariant_group_names()),
            "mock_env_seed": SEED,
            "mock_env_range": [60, 120],
            "mock_env_metrics": ["score"],
        },
    }
    trial_runners: list[TrialRunner] = TrialRunner.create_from_json(
        config_loader=config_loader,
        global_config=global_config,
        tunable_groups=exp.tunables,
        env_json=json.dumps(mock_env_config),
        svcs_json=None,
        num_trial_runners=TRIAL_RUNNER_COUNT,
    )

    opt = MockOptimizer(
        tunables=exp.tunables,
        config={
            "optimization_targets": exp.opt_targets,
            "seed": SEED,
            # This should be the default, so we leave it omitted for now to test the default.
            # But the test logic relies on this (e.g., trial 1 is config 1 is the
            # default values for the tunable params)
            # "start_with_defaults": True,
            "max_suggestions": MAX_TRIALS,
        },
        global_config=global_config,
    )

    scheduler = SyncScheduler(
        # All config values can be overridden from global config
        config={
            "experiment_id": exp.experiment_id,
            "trial_id": exp.trial_id,
            "config_id": -1,
            "trial_config_repeat_count": CONFIG_TRIAL_REPEAT_COUNT,
            "max_trials": MAX_TRIALS,
        },
        global_config=global_config,
        trial_runners=trial_runners,
        optimizer=opt,
        storage=storage,
        root_env_config=exp.root_env_config,
    )

    # Add some trial data to that experiment by "running" it.
    with scheduler:
        scheduler.start()
        scheduler.teardown()

    # Re-read the experiment from storage (by id) rather than returning `exp`,
    # which is not updated by the run.
    return storage.experiments[exp.experiment_id]

198 

199 

@pytest.fixture
def exp_data(
    storage: SqlStorage,
    exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Provide ExperimentData produced by a simulated run of ``exp_storage``."""
    simulated = _dummy_run_exp(storage=storage, exp=exp_storage)
    return simulated

207 

208 

@pytest.fixture
def exp_no_tunables_data(
    storage: SqlStorage,
    exp_no_tunables_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Provide ExperimentData from a simulated run of an experiment with no tunables."""
    simulated = _dummy_run_exp(storage=storage, exp=exp_no_tunables_storage)
    return simulated

216 

217 

@pytest.fixture
def mixed_numerics_exp_data(
    storage: SqlStorage,
    mixed_numerics_exp_storage: SqlStorage.Experiment,
) -> ExperimentData:
    """Provide ExperimentData from a simulated run with mixed numeric tunable types."""
    simulated = _dummy_run_exp(storage=storage, exp=mixed_numerics_exp_storage)
    return simulated