#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""Base class for the optimization loop scheduling policies."""

import json
import logging
from abc import ABCMeta, abstractmethod
from datetime import datetime
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type

from pytz import UTC
from typing_extensions import Literal

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.environments.base_environment import Environment
from mlos_bench.optimizers.base_optimizer import Optimizer
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import merge_parameters

_LOG = logging.getLogger(__name__)


class Scheduler(metaclass=ABCMeta):
    # pylint: disable=too-many-instance-attributes
    """Base class for the optimization loop scheduling policies."""

    def __init__(  # pylint: disable=too-many-arguments
        self,
        *,
        config: Dict[str, Any],
        global_config: Dict[str, Any],
        environment: Environment,
        optimizer: Optimizer,
        storage: Storage,
        root_env_config: str,
    ):
        """
        Create a new instance of the scheduler. The constructor of this and the derived
        classes is called by the persistence service after reading the class JSON
        configuration. Other objects like the Environment and Optimizer are provided by
        the Launcher.

        Parameters
        ----------
        config : dict
            The configuration for the scheduler.
        global_config : dict
            The global configuration for the experiment.
        environment : Environment
            The environment to benchmark/optimize.
        optimizer : Optimizer
            The optimizer to use.
        storage : Storage
            The storage to use.
        root_env_config : str
            Path to the root environment configuration.
        """
        self.global_config = global_config
        config = merge_parameters(
            dest=config.copy(),
            source=global_config,
            required_keys=["experiment_id", "trial_id"],
        )
        self._validate_json_config(config)

        self._experiment_id = config["experiment_id"].strip()
        self._trial_id = int(config["trial_id"])
        self._config_id = int(config.get("config_id", -1))
        self._max_trials = int(config.get("max_trials", -1))
        self._trial_count = 0

        self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
        if self._trial_config_repeat_count <= 0:
            raise ValueError(
                f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}"
            )

        self._do_teardown = bool(config.get("teardown", True))

        self.experiment: Optional[Storage.Experiment] = None
        self.environment = environment
        self.optimizer = optimizer
        self.storage = storage
        self._root_env_config = root_env_config
        self._last_trial_id = -1
        self._ran_trials: List[Storage.Trial] = []

        _LOG.debug("Scheduler instantiated: %s :: %s", self, config)

    def _validate_json_config(self, config: dict) -> None:
        """Reconstruct a basic JSON config that this class might have been instantiated
        from, in order to validate configs provided outside the file loading
        mechanism.
        """
        json_config: dict = {
            "class": self.__class__.__module__ + "." + self.__class__.__name__,
        }
        if config:
            json_config["config"] = config.copy()
            # The JSON schema does not allow -1 as a valid value for config_id.
            # Since it is just a default placeholder value and not required, we can
            # safely remove it from the config copy prior to validation.
            config_id = json_config["config"].get("config_id")
            if config_id is not None and isinstance(config_id, int) and config_id < 0:
                json_config["config"].pop("config_id")
        ConfigSchema.SCHEDULER.validate(json_config)
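
    # For reference (a hypothetical sketch; the class path and field values are
    # illustrative), the reconstructed JSON config validated above might look like:
    #
    #     {
    #         "class": "mlos_bench.schedulers.SyncScheduler",
    #         "config": {
    #             "experiment_id": "MyExperiment",
    #             "trial_id": 1,
    #             "max_trials": 100,
    #             "trial_config_repeat_count": 3,
    #             "teardown": true
    #         }
    #     }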

    @property
    def trial_config_repeat_count(self) -> int:
        """Gets the number of trials to run for a given config."""
        return self._trial_config_repeat_count

    @property
    def trial_count(self) -> int:
        """Gets the current number of trials run for the experiment."""
        return self._trial_count

    @property
    def max_trials(self) -> int:
        """Gets the maximum number of trials to run for a given experiment, or -1 for no
        limit.
        """
        return self._max_trials

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the Scheduler (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the Scheduler.
        """
        return self.__class__.__name__

    def __enter__(self) -> "Scheduler":
        """Enter the scheduler's context."""
        _LOG.debug("Scheduler START :: %s", self)
        assert self.experiment is None
        self.environment.__enter__()
        self.optimizer.__enter__()
        # Start a new experiment or resume an existing one. Verify that the
        # experiment configuration is compatible with the previous runs.
        # If the `merge` config parameter is present, merge in the data
        # from other experiments and check for compatibility.
        self.experiment = self.storage.experiment(
            experiment_id=self._experiment_id,
            trial_id=self._trial_id,
            root_env_config=self._root_env_config,
            description=self.environment.name,
            tunables=self.environment.tunable_params,
            opt_targets=self.optimizer.targets,
        ).__enter__()
        return self

    def __exit__(
        self,
        ex_type: Optional[Type[BaseException]],
        ex_val: Optional[BaseException],
        ex_tb: Optional[TracebackType],
    ) -> Literal[False]:
        """Exit the context of the scheduler."""
        if ex_val is None:
            _LOG.debug("Scheduler END :: %s", self)
        else:
            assert ex_type and ex_val
            _LOG.warning("Scheduler END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
        assert self.experiment is not None
        self.experiment.__exit__(ex_type, ex_val, ex_tb)
        self.optimizer.__exit__(ex_type, ex_val, ex_tb)
        self.environment.__exit__(ex_type, ex_val, ex_tb)
        self.experiment = None
        return False  # Do not suppress exceptions

    @abstractmethod
    def start(self) -> None:
        """Start the optimization loop."""
        assert self.experiment is not None
        _LOG.info(
            "START: Experiment: %s Env: %s Optimizer: %s",
            self.experiment,
            self.environment,
            self.optimizer,
        )
        if _LOG.isEnabledFor(logging.INFO):
            _LOG.info("Root Environment:\n%s", self.environment.pprint())

        if self._config_id > 0:
            tunables = self.load_config(self._config_id)
            self.schedule_trial(tunables)

    def teardown(self) -> None:
        """
        Tear down the environment.

        Call it after the completion of `.start()`, while still within the scheduler's
        context.
        """
        assert self.experiment is not None
        if self._do_teardown:
            self.environment.teardown()

    def get_best_observation(self) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]:
        """Get the best observation from the optimizer."""
        (best_score, best_config) = self.optimizer.get_best_observation()
        _LOG.info("Env: %s best score: %s", self.environment, best_score)
        return (best_score, best_config)

    def load_config(self, config_id: int) -> TunableGroups:
        """Load the existing tunable configuration from the storage."""
        assert self.experiment is not None
        tunable_values = self.experiment.load_tunable_config(config_id)
        tunables = self.environment.tunable_params.assign(tunable_values)
        _LOG.info("Load config from storage: %d", config_id)
        if _LOG.isEnabledFor(logging.DEBUG):
            _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2))
        return tunables

    def _schedule_new_optimizer_suggestions(self) -> bool:
        """
        Optimizer part of the loop.

        Load the results of the executed trials into the optimizer, suggest new
        configurations, and add them to the queue. Return True if optimization is not
        over, False otherwise.
        """
        assert self.experiment is not None
        (trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id)
        _LOG.info("QUEUE: Update the optimizer with trial results: %s", trial_ids)
        self.optimizer.bulk_register(configs, scores, status)
        self._last_trial_id = max(trial_ids, default=self._last_trial_id)

        not_done = self.not_done()
        if not_done:
            tunables = self.optimizer.suggest()
            self.schedule_trial(tunables)

        return not_done
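
    # A hedged sketch of how a derived scheduler's `start()` might drive the two
    # halves of the loop (`_run_schedule()` and `_schedule_new_optimizer_suggestions()`);
    # the `supports_preload` warm-up handling here is an assumption, not part of
    # this base class:
    #
    #     is_warm_up = self.optimizer.supports_preload
    #     not_done = True
    #     while not_done:
    #         self._run_schedule(is_warm_up)
    #         not_done = self._schedule_new_optimizer_suggestions()
    #         is_warm_up = False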

    def schedule_trial(self, tunables: TunableGroups) -> None:
        """Add a configuration to the queue of trials."""
        for repeat_i in range(1, self._trial_config_repeat_count + 1):
            self._add_trial_to_queue(
                tunables,
                config={
                    # Add some additional metadata to track for the trial, such as
                    # the optimizer config used.
                    # Note: these values are unfortunately mutable at the moment.
                    # Consider them hints about what the config was when the trial
                    # *started*. It is possible that the experiment configs changed
                    # when the experiment was resumed (since that is not currently
                    # prevented).
                    "optimizer": self.optimizer.name,
                    "repeat_i": repeat_i,
                    "is_defaults": tunables.is_defaults(),
                    **{
                        f"opt_{key}_{i}": val
                        for (i, opt_target) in enumerate(self.optimizer.targets.items())
                        for (key, val) in zip(["target", "direction"], opt_target)
                    },
                },
            )
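
    # For illustration (a sketch; the names and values are hypothetical), with an
    # optimizer minimizing "score" and maximizing "throughput", each queued
    # trial's metadata dict above would expand to something like:
    #
    #     {
    #         "optimizer": "MlosCoreOptimizer",
    #         "repeat_i": 1,
    #         "is_defaults": False,
    #         "opt_target_0": "score",
    #         "opt_direction_0": "min",
    #         "opt_target_1": "throughput",
    #         "opt_direction_1": "max",
    #     }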

    def _add_trial_to_queue(
        self,
        tunables: TunableGroups,
        ts_start: Optional[datetime] = None,
        config: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Add a configuration to the queue of trials.

        A wrapper for the `Experiment.new_trial` method.
        """
        assert self.experiment is not None
        trial = self.experiment.new_trial(tunables, ts_start, config)
        _LOG.info("QUEUE: Add new trial: %s", trial)

    def _run_schedule(self, running: bool = False) -> None:
        """
        Scheduler part of the loop.

        Check for pending trials in the queue and run them.
        """
        assert self.experiment is not None
        for trial in self.experiment.pending_trials(datetime.now(UTC), running=running):
            self.run_trial(trial)

    def not_done(self) -> bool:
        """
        Check the stopping conditions.

        By default, stop when the optimizer converges or the maximum number of trials
        is reached.
        """
        return self.optimizer.not_converged() and (
            self._trial_count < self._max_trials or self._max_trials <= 0
        )
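
    # For example, with max_trials=100 the loop stops after 100 trials even if
    # the optimizer has not converged, while max_trials=-1 (the default) leaves
    # optimizer convergence as the only stopping condition.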

    @abstractmethod
    def run_trial(self, trial: Storage.Trial) -> None:
        """
        Set up and run a single trial.

        Save the results in the storage.
        """
        assert self.experiment is not None
        self._trial_count += 1
        self._ran_trials.append(trial)
        _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial)

    @property
    def ran_trials(self) -> List[Storage.Trial]:
        """Get the list of trials that were run."""
        return self._ran_trials