Coverage for mlos_bench/mlos_bench/schedulers/trial_runner.py: 95%
82 statements
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
« prev ^ index » next coverage.py v7.6.10, created at 2025-01-21 01:50 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Simple class to run an individual Trial on a given Environment."""
7import logging
8from datetime import datetime
9from types import TracebackType
10from typing import Any, Literal
12from pytz import UTC
14from mlos_bench.environments.base_environment import Environment
15from mlos_bench.environments.status import Status
16from mlos_bench.event_loop_context import EventLoopContext
17from mlos_bench.services.base_service import Service
18from mlos_bench.services.config_persistence import ConfigPersistenceService
19from mlos_bench.services.local.local_exec import LocalExecService
20from mlos_bench.services.types import SupportsConfigLoading
21from mlos_bench.storage.base_storage import Storage
22from mlos_bench.tunables.tunable_groups import TunableGroups
24_LOG = logging.getLogger(__name__)
class TrialRunner:
    """
    Simple class to help run an individual Trial on an environment.

    TrialRunner manages the lifecycle of a single trial: setting up the
    Environment, running it, storing the results and telemetry in the Trial
    storage, and tearing the Environment down.

    An EventLoopContext is allocated for each runner for async status polling,
    but it is not entered yet (see the TODO notes below), so trials currently
    run synchronously and block until the Environment returns its final result.

    Multiple TrialRunners can be used in a multi-processing pool to run multiple trials
    in parallel, for instance.
    """

    @classmethod
    def create_from_json(
        cls,
        *,
        config_loader: Service,
        env_json: str,
        svcs_json: str | list[str] | None = None,
        num_trial_runners: int = 1,
        tunable_groups: TunableGroups | None = None,
        global_config: dict[str, Any] | None = None,
    ) -> list["TrialRunner"]:
        # pylint: disable=too-many-arguments
        """
        Create a list of TrialRunner instances, and their associated Environments and
        Services, from JSON configurations.

        Since each TrialRunner instance is independent, they can be run in parallel,
        and hence must each get their own copy of the Environment and Services to
        operate on.

        Each TrialRunner gets its own shallow copy of the global_config with a
        unique, 1-based trial_runner_id added to it.

        Parameters
        ----------
        config_loader : Service
            A service instance capable of loading configuration (i.e., SupportsConfigLoading).
        env_json : str
            JSON file or string representing the environment configuration.
        svcs_json : str | list[str] | None
            JSON file(s) or string(s) representing the Services configuration.
        num_trial_runners : int
            Number of TrialRunner instances to create. Default is 1.
            Values below 1 produce an empty list.
        tunable_groups : TunableGroups | None
            TunableGroups instance to use as the parent Tunables for the
            environment. Default is None.
        global_config : dict[str, Any] | None
            Global configuration parameters. Default is None.

        Returns
        -------
        list[TrialRunner]
            A list of TrialRunner instances created from the provided configuration.
        """
        assert isinstance(config_loader, SupportsConfigLoading)
        svcs_json = svcs_json or []
        tunable_groups = tunable_groups or TunableGroups()
        global_config = global_config or {}
        trial_runners: list[TrialRunner] = []
        for trial_runner_id in range(1, num_trial_runners + 1):  # use 1-based indexing
            # Make a fresh Environment and Services copy for each TrialRunner.
            # Give each global_config copy its own unique trial_runner_id.
            # This is important in case multiple TrialRunners are running in parallel.
            global_config_copy = global_config.copy()
            global_config_copy["trial_runner_id"] = trial_runner_id
            # Each Environment's parent service starts with at least a
            # LocalExecService in addition to the ConfigLoader.
            parent_service: Service = ConfigPersistenceService(
                config={"config_path": config_loader.get_config_paths()},
                global_config=global_config_copy,
            )
            parent_service = LocalExecService(parent=parent_service)
            parent_service = config_loader.load_services(
                svcs_json,
                global_config_copy,
                parent_service,
            )
            env = config_loader.load_environment(
                env_json,
                tunable_groups.copy(),
                global_config_copy,
                service=parent_service,
            )
            trial_runners.append(TrialRunner(trial_runner_id, env))
        return trial_runners

    def __init__(self, trial_runner_id: int, env: Environment) -> None:
        """
        Bind a TrialRunner id to the Environment it will run trials on.

        Parameters
        ----------
        trial_runner_id : int
            Id of this TrialRunner. Must match the "trial_runner_id" entry in
            the Environment's parameters.
        env : Environment
            The Environment this TrialRunner operates on.
        """
        self._trial_runner_id = trial_runner_id
        self._env = env
        assert self._env.parameters["trial_runner_id"] == self._trial_runner_id
        self._in_context = False
        self._is_running = False
        # Allocated up front, but not entered yet (see __enter__ / __exit__).
        self._event_loop_context = EventLoopContext()

    def __repr__(self) -> str:
        return (
            f"TrialRunner({self.trial_runner_id}, {repr(self.environment)}"
            f"""[trial_runner_id={self.environment.parameters.get("trial_runner_id")}])"""
        )

    def __str__(self) -> str:
        return f"TrialRunner({self.trial_runner_id}, {str(self.environment)})"

    @property
    def trial_runner_id(self) -> int:
        """Get the TrialRunner's id."""
        return self._trial_runner_id

    @property
    def environment(self) -> Environment:
        """Get the Environment."""
        return self._env

    def __enter__(self) -> "TrialRunner":
        """Enter the Environment's context; the runner must not already be in one."""
        assert not self._in_context
        _LOG.debug("TrialRunner START :: %s", self)
        # TODO: self._event_loop_context.enter()
        self._env.__enter__()
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """Exit the Environment's context; exceptions are never suppressed."""
        assert self._in_context
        _LOG.debug("TrialRunner END :: %s", self)
        try:
            self._env.__exit__(ex_type, ex_val, ex_tb)
            # TODO: self._event_loop_context.exit()
        finally:
            # Leave the context even if the Environment's own exit fails, so
            # the flag never stays stuck on True.
            self._in_context = False
        return False  # Do not suppress exceptions

    @property
    def is_running(self) -> bool:
        """Get the running state of the current TrialRunner."""
        return self._is_running

    def run_trial(
        self,
        trial: Storage.Trial,
        global_config: dict[str, Any] | None = None,
    ) -> None:
        """
        Run a single trial on this TrialRunner's Environment and store the results in
        the backend Trial Storage.

        If the Environment setup fails, the trial is marked as FAILED in the
        storage and the method returns without running it. The running flag is
        cleared on every exit path, including failures and exceptions.

        Parameters
        ----------
        trial : Storage.Trial
            A Storage class based Trial used to persist the experiment trial data.
            Its trial_runner_id must match this TrialRunner's id.
        global_config : dict[str, Any] | None
            Global configuration parameters.

        Returns
        -------
        None
            The status, results, and telemetry are written to the `trial`
            rather than returned to the caller.
        """
        assert self._in_context

        assert not self._is_running
        self._is_running = True
        try:
            assert trial.trial_runner_id == self.trial_runner_id, (
                f"TrialRunner {self} should not run trial {trial} "
                f"with different trial_runner_id {trial.trial_runner_id}."
            )

            if not self.environment.setup(trial.tunables, trial.config(global_config)):
                _LOG.warning("Setup failed: %s :: %s", self.environment, trial.tunables)
                # FIXME: Use the actual timestamp from the environment.
                _LOG.info("TrialRunner: Update trial results: %s :: %s", trial, Status.FAILED)
                trial.update(Status.FAILED, datetime.now(UTC))
                return

            # TODO: start background status polling of the environments in the event loop.

            # Block and wait for the final result.
            (status, timestamp, results) = self.environment.run()
            _LOG.info("TrialRunner Results: %s :: %s\n%s", trial.tunables, status, results)

            # In async mode (TODO), poll the environment for status and telemetry
            # and update the storage with the intermediate results.
            (_status, _timestamp, telemetry) = self.environment.status()

            # Use the status and timestamp from `.run()` as it is the final status
            # of the experiment.
            # TODO: Use the `.status()` output in async mode.
            trial.update_telemetry(status, timestamp, telemetry)

            trial.update(status, timestamp, results)
            _LOG.info("TrialRunner: Update trial results: %s :: %s %s", trial, status, results)
        finally:
            # Always clear the flag: otherwise a failed setup or an exception
            # would leave the runner permanently marked as running, and the
            # next run_trial() call would trip the assertion above.
            self._is_running = False

    def teardown(self) -> None:
        """
        Tear down the Environment.

        Call it after the completion of one (or more) `.run_trial()` calls in the
        TrialRunner context.
        """
        assert self._in_context
        self._env.teardown()