Coverage for mlos_bench/mlos_bench/schedulers/trial_runner.py: 95%

82 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-21 01:50 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Simple class to run an individual Trial on a given Environment.""" 

6 

7import logging 

8from datetime import datetime 

9from types import TracebackType 

10from typing import Any, Literal 

11 

12from pytz import UTC 

13 

14from mlos_bench.environments.base_environment import Environment 

15from mlos_bench.environments.status import Status 

16from mlos_bench.event_loop_context import EventLoopContext 

17from mlos_bench.services.base_service import Service 

18from mlos_bench.services.config_persistence import ConfigPersistenceService 

19from mlos_bench.services.local.local_exec import LocalExecService 

20from mlos_bench.services.types import SupportsConfigLoading 

21from mlos_bench.storage.base_storage import Storage 

22from mlos_bench.tunables.tunable_groups import TunableGroups 

23 

24_LOG = logging.getLogger(__name__) 

25 

26 

class TrialRunner:
    """
    Simple class to help run an individual Trial on an environment.

    TrialRunner manages the lifecycle of a single trial: environment setup, a
    blocking run, persisting the results and telemetry, and (on request) teardown.

    Asynchronous status polling via the EventLoopContext is not implemented yet
    (see the TODO notes in the methods): the context object is created in the
    constructor but is never entered or exited.

    Multiple TrialRunners can be used in a multi-processing pool to run multiple trials
    in parallel, for instance.
    """

    @classmethod
    def create_from_json(
        cls,
        *,
        config_loader: Service,
        env_json: str,
        svcs_json: str | list[str] | None = None,
        num_trial_runners: int = 1,
        tunable_groups: TunableGroups | None = None,
        global_config: dict[str, Any] | None = None,
    ) -> list["TrialRunner"]:
        # pylint: disable=too-many-arguments
        """
        Create a list of TrialRunner instances, and their associated Environments and
        Services, from JSON configurations.

        Since each TrialRunner instance is independent, they can be run in parallel,
        and hence must each get their own copy of the Environment and Services to
        operate on.

        Each TrialRunner gets its own (shallow) copy of the global_config, extended
        with a unique 1-based trial_runner_id.

        Parameters
        ----------
        config_loader : Service
            A service instance capable of loading configuration (i.e., SupportsConfigLoading).
        env_json : str
            JSON file or string representing the environment configuration.
        svcs_json : str | list[str] | None
            JSON file(s) or string(s) representing the Services configuration.
        num_trial_runners : int
            Number of TrialRunner instances to create. Default is 1.
        tunable_groups : TunableGroups | None
            TunableGroups instance to use as the parent Tunables for the
            environment. Default is None.
        global_config : dict[str, Any] | None
            Global configuration parameters. Default is None.

        Returns
        -------
        list[TrialRunner]
            A list of TrialRunner instances created from the provided configuration.
        """
        assert isinstance(config_loader, SupportsConfigLoading)
        svcs_json = svcs_json or []
        tunable_groups = tunable_groups or TunableGroups()
        global_config = global_config or {}
        trial_runners: list[TrialRunner] = []
        for trial_runner_id in range(1, num_trial_runners + 1):  # use 1-based indexing
            # Make a fresh Environment and Services copy for each TrialRunner.
            # Give each global_config copy its own unique trial_runner_id.
            # This is important in case multiple TrialRunners are running in parallel.
            global_config_copy = global_config.copy()
            global_config_copy["trial_runner_id"] = trial_runner_id
            # Each Environment's parent service starts with at least a
            # LocalExecService in addition to the ConfigLoader.
            parent_service: Service = ConfigPersistenceService(
                config={"config_path": config_loader.get_config_paths()},
                global_config=global_config_copy,
            )
            parent_service = LocalExecService(parent=parent_service)
            parent_service = config_loader.load_services(
                svcs_json,
                global_config_copy,
                parent_service,
            )
            env = config_loader.load_environment(
                env_json,
                tunable_groups.copy(),
                global_config_copy,
                service=parent_service,
            )
            trial_runners.append(TrialRunner(trial_runner_id, env))
        return trial_runners

    def __init__(self, trial_runner_id: int, env: Environment) -> None:
        """
        Bind the runner to its id and Environment.

        Parameters
        ----------
        trial_runner_id : int
            Id of this TrialRunner; must match the Environment's
            "trial_runner_id" parameter.
        env : Environment
            The Environment this runner will operate on.
        """
        self._trial_runner_id = trial_runner_id
        self._env = env
        assert self._env.parameters["trial_runner_id"] == self._trial_runner_id
        self._in_context = False
        self._is_running = False
        # Created for the planned async status polling; not entered anywhere yet.
        self._event_loop_context = EventLoopContext()

    def __repr__(self) -> str:
        return (
            f"TrialRunner({self.trial_runner_id}, {repr(self.environment)}"
            f"""[trial_runner_id={self.environment.parameters.get("trial_runner_id")}])"""
        )

    def __str__(self) -> str:
        return f"TrialRunner({self.trial_runner_id}, {str(self.environment)})"

    @property
    def trial_runner_id(self) -> int:
        """Get the TrialRunner's id."""
        return self._trial_runner_id

    @property
    def environment(self) -> Environment:
        """Get the Environment."""
        return self._env

    def __enter__(self) -> "TrialRunner":
        """Enter the Environment's context and mark this runner as in-context."""
        assert not self._in_context
        _LOG.debug("TrialRunner START :: %s", self)
        # TODO: self._event_loop_context.enter()
        self._env.__enter__()
        # Only flag the context as entered once the Environment entered successfully.
        self._in_context = True
        return self

    def __exit__(
        self,
        ex_type: type[BaseException] | None,
        ex_val: BaseException | None,
        ex_tb: TracebackType | None,
    ) -> Literal[False]:
        """Exit the Environment's context; exceptions are never suppressed."""
        assert self._in_context
        _LOG.debug("TrialRunner END :: %s", self)
        self._env.__exit__(ex_type, ex_val, ex_tb)
        # TODO: self._event_loop_context.exit()
        self._in_context = False
        return False  # Do not suppress exceptions

    @property
    def is_running(self) -> bool:
        """Get the running state of the current TrialRunner."""
        return self._is_running

    def run_trial(
        self,
        trial: Storage.Trial,
        global_config: dict[str, Any] | None = None,
    ) -> None:
        """
        Run a single trial on this TrialRunner's Environment and store the results in
        the backend Trial Storage.

        Nothing is returned: the final status, timestamp, results, and telemetry are
        persisted through `trial.update()` and `trial.update_telemetry()`. If the
        Environment setup fails, the trial is recorded as `Status.FAILED` and the
        run is skipped.

        The `is_running` flag is set for the duration of the call and is always
        cleared on the way out, including on setup failure or when an exception
        propagates, so that the runner can be reused for the next trial.

        Parameters
        ----------
        trial : Storage.Trial
            A Storage class based Trial used to persist the experiment trial data.
        global_config : dict
            Global configuration parameters.
        """
        # Check all preconditions before touching the running flag.
        assert self._in_context
        assert not self._is_running
        assert trial.trial_runner_id == self.trial_runner_id, (
            f"TrialRunner {self} should not run trial {trial} "
            f"with different trial_runner_id {trial.trial_runner_id}."
        )

        self._is_running = True
        try:
            if not self.environment.setup(trial.tunables, trial.config(global_config)):
                _LOG.warning("Setup failed: %s :: %s", self.environment, trial.tunables)
                # FIXME: Use the actual timestamp from the environment.
                _LOG.info("TrialRunner: Update trial results: %s :: %s", trial, Status.FAILED)
                trial.update(Status.FAILED, datetime.now(UTC))
                return

            # TODO: start background status polling of the environments in the event loop.

            # Block and wait for the final result.
            (status, timestamp, results) = self.environment.run()
            _LOG.info("TrialRunner Results: %s :: %s\n%s", trial.tunables, status, results)

            # In async mode (TODO), poll the environment for status and telemetry
            # and update the storage with the intermediate results.
            (_status, _timestamp, telemetry) = self.environment.status()

            # Use the status and timestamp from `.run()` as it is the final status
            # of the experiment.
            # TODO: Use the `.status()` output in async mode.
            trial.update_telemetry(status, timestamp, telemetry)

            trial.update(status, timestamp, results)
            _LOG.info("TrialRunner: Update trial results: %s :: %s %s", trial, status, results)
        finally:
            # Without this, a failed setup or an exception would leave the flag
            # stuck at True and every later run_trial() call would assert.
            self._is_running = False

    def teardown(self) -> None:
        """
        Tear down the Environment.

        Call it after the completion of one (or more) `.run_trial()` calls in the
        TrialRunner context.
        """
        assert self._in_context
        self._env.teardown()